datapath: stt: Fix device list management.
datapath/linux/compat/stt.c (cascardo/ovs.git)
1 /*
2  * Stateless TCP Tunnel (STT) vport.
3  *
4  * Copyright (c) 2015 Nicira, Inc.
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * as published by the Free Software Foundation; either version
9  * 2 of the License, or (at your option) any later version.
10  */
11
12 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
13 #include <asm/unaligned.h>
14
15 #include <linux/delay.h>
16 #include <linux/flex_array.h>
17 #include <linux/if.h>
18 #include <linux/if_vlan.h>
19 #include <linux/ip.h>
20 #include <linux/ipv6.h>
21 #include <linux/jhash.h>
22 #include <linux/list.h>
23 #include <linux/log2.h>
24 #include <linux/module.h>
25 #include <linux/net.h>
26 #include <linux/netfilter.h>
27 #include <linux/percpu.h>
28 #include <linux/skbuff.h>
29 #include <linux/tcp.h>
30 #include <linux/workqueue.h>
31
32 #include <net/dst_metadata.h>
33 #include <net/icmp.h>
34 #include <net/inet_ecn.h>
35 #include <net/ip.h>
36 #include <net/ip_tunnels.h>
37 #include <net/ip6_checksum.h>
38 #include <net/net_namespace.h>
39 #include <net/netns/generic.h>
40 #include <net/sock.h>
41 #include <net/stt.h>
42 #include <net/tcp.h>
43 #include <net/udp.h>
44
45 #include "gso.h"
46 #include "compat.h"
47
48 #define STT_NETDEV_VER  "0.1"
49 #define STT_DST_PORT 7471
50
51 #ifdef OVS_STT
52 #define STT_VER 0
53
54 /* @dev: The net_device backing this STT port.
55  * @net: Network namespace the device belongs to.
56  * @next: Entry in the per-net list of all STT ports.
57  * @up_next: Entry in the per-net list of ports in IFF_UP state.
58  * @sock: Fake TCP socket for the STT port.
59  * @dst_port: TCP destination port for this STT device.
60  */
61 struct stt_dev {
62         struct net_device       *dev;
63         struct net              *net;
64         struct list_head        next;
65         struct list_head        up_next;
66         struct socket           *sock;
67         __be16                  dst_port;
68 };
69
70 #define STT_CSUM_VERIFIED       BIT(0)
71 #define STT_CSUM_PARTIAL        BIT(1)
72 #define STT_PROTO_IPV4          BIT(2)
73 #define STT_PROTO_TCP           BIT(3)
74 #define STT_PROTO_TYPES         (STT_PROTO_IPV4 | STT_PROTO_TCP)
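
/* Editorial quick reference (not part of the original source; derived from
 * __push_stt_header() and set_offloads() below): for CHECKSUM_PARTIAL
 * packets the STT_PROTO_* bits identify the inner protocol as follows.
 *
 *	STT_PROTO_IPV4 | STT_PROTO_TCP	-> TCP over IPv4
 *	STT_PROTO_TCP			-> TCP over IPv6
 *	STT_PROTO_IPV4			-> UDP over IPv4
 *	(neither bit set)		-> UDP over IPv6
 */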
75
76 #define SUPPORTED_GSO_TYPES (SKB_GSO_TCPV4 | SKB_GSO_UDP | SKB_GSO_DODGY | \
77                              SKB_GSO_TCPV6)
78
79 /* The length and offset of a fragment are encoded in the sequence number.
80  * STT_SEQ_LEN_SHIFT is the left shift needed to store the length.
81  * STT_SEQ_OFFSET_MASK is the mask to extract the offset.
82  */
83 #define STT_SEQ_LEN_SHIFT 16
84 #define STT_SEQ_OFFSET_MASK (BIT(STT_SEQ_LEN_SHIFT) - 1)
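
/* Minimal illustrative helpers (an editorial sketch, not part of the original
 * file; the stt_example_* names are invented here): they show how the total
 * length and fragment offset are packed into and unpacked from the TCP
 * sequence number, as done by __push_stt_header() and reassemble() below.
 */
static inline u32 stt_example_seq_encode(u16 tot_len, u16 offset)
{
	/* e.g. tot_len 0x2000, offset 0x05a8 -> seq 0x200005a8 */
	return ((u32)tot_len << STT_SEQ_LEN_SHIFT) | offset;
}

static inline u16 stt_example_seq_tot_len(u32 seq)
{
	return seq >> STT_SEQ_LEN_SHIFT;
}

static inline u16 stt_example_seq_offset(u32 seq)
{
	return seq & STT_SEQ_OFFSET_MASK;
}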
85
86 /* The maximum amount of memory used to store packets waiting to be reassembled
87  * on a given CPU.  Once this threshold is exceeded we will begin freeing the
88  * least recently used fragments.
89  */
90 #define REASM_HI_THRESH (4 * 1024 * 1024)
91 /* The target for the high memory evictor.  Once we have exceeded
92  * REASM_HI_THRESH, we will continue freeing fragments until we hit
93  * this limit.
94  */
95 #define REASM_LO_THRESH (3 * 1024 * 1024)
96 /* The length of time a given packet has to be reassembled from the time the
97  * first fragment arrives.  Once this limit is exceeded it becomes available
98  * for cleaning.
99  */
100 #define FRAG_EXP_TIME (30 * HZ)
101 /* Number of hash entries.  Each entry has only a single slot to hold a packet,
102  * so if there are collisions, we will drop packets.  This is allocated
103  * per-CPU, and each entry consists of a struct pkt_frag.
104  */
105 #define FRAG_HASH_SHIFT         8
106 #define FRAG_HASH_ENTRIES       BIT(FRAG_HASH_SHIFT)
107 #define FRAG_HASH_SEGS          ((sizeof(u32) * 8) / FRAG_HASH_SHIFT)
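
/* Editorial note (not in the original source): lookup_frag() below consumes
 * the 32-bit hash FRAG_HASH_SHIFT bits at a time, probing FRAG_HASH_SEGS
 * (4 with an 8-bit shift) candidate slots per packet, roughly:
 *
 *	for (i = 0; i < FRAG_HASH_SEGS; i++) {
 *		frag = flex_array_get(frag_hash, hash & (FRAG_HASH_ENTRIES - 1));
 *		...
 *		hash >>= FRAG_HASH_SHIFT;
 *	}
 */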
108
109 #define CLEAN_PERCPU_INTERVAL (30 * HZ)
110
111 struct pkt_key {
112         __be32 saddr;
113         __be32 daddr;
114         __be32 pkt_seq;
115         u32 mark;
116 };
117
118 struct pkt_frag {
119         struct sk_buff *skbs;
120         unsigned long timestamp;
121         struct list_head lru_node;
122         struct pkt_key key;
123 };
124
125 struct stt_percpu {
126         struct flex_array *frag_hash;
127         struct list_head frag_lru;
128         unsigned int frag_mem_used;
129
130         /* Protect frags table. */
131         spinlock_t lock;
132 };
133
134 struct first_frag {
135         struct sk_buff *last_skb;
136         unsigned int mem_used;
137         u16 tot_len;
138         u16 rcvd_len;
139         bool set_ecn_ce;
140 };
141
142 struct frag_skb_cb {
143         u16 offset;
144
145         /* Only valid for the first skb in the chain. */
146         struct first_frag first;
147 };
148
149 #define FRAG_CB(skb) ((struct frag_skb_cb *)(skb)->cb)
150
151 /* per-network namespace private data for this module */
152 struct stt_net {
153         struct list_head stt_list;
154         struct list_head stt_up_list;   /* Devices which are in IFF_UP state. */
155         int n_tunnels;
156 };
157
158 static int stt_net_id;
159
160 static struct stt_percpu __percpu *stt_percpu_data __read_mostly;
161 static u32 frag_hash_seed __read_mostly;
162
163 /* Protects sock-hash and refcounts. */
164 static DEFINE_MUTEX(stt_mutex);
165
166 static int n_tunnels;
167 static DEFINE_PER_CPU(u32, pkt_seq_counter);
168
169 static void clean_percpu(struct work_struct *work);
170 static DECLARE_DELAYED_WORK(clean_percpu_wq, clean_percpu);
171
172 static struct stt_dev *stt_find_up_dev(struct net *net, __be16 port)
173 {
174         struct stt_net *sn = net_generic(net, stt_net_id);
175         struct stt_dev *stt_dev;
176
177         list_for_each_entry_rcu(stt_dev, &sn->stt_up_list, up_next) {
178                 if (stt_dev->dst_port == port)
179                         return stt_dev;
180         }
181         return NULL;
182 }
183
184 static __be32 ack_seq(void)
185 {
186 #if NR_CPUS <= 65536
187         u32 pkt_seq, ack;
188
189         pkt_seq = this_cpu_read(pkt_seq_counter);
190         ack = pkt_seq << ilog2(NR_CPUS) | smp_processor_id();
191         this_cpu_inc(pkt_seq_counter);
192
193         return (__force __be32)ack;
194 #else
195 #error "Support for greater than 64k CPUs not implemented"
196 #endif
197 }
198
199 static int clear_gso(struct sk_buff *skb)
200 {
201         struct skb_shared_info *shinfo = skb_shinfo(skb);
202         int err;
203
204         if (shinfo->gso_type == 0 && shinfo->gso_size == 0 &&
205             shinfo->gso_segs == 0)
206                 return 0;
207
208         err = skb_unclone(skb, GFP_ATOMIC);
209         if (unlikely(err))
210                 return err;
211
212         shinfo = skb_shinfo(skb);
213         shinfo->gso_type = 0;
214         shinfo->gso_size = 0;
215         shinfo->gso_segs = 0;
216         return 0;
217 }
218
219 static struct sk_buff *normalize_frag_list(struct sk_buff *head,
220                                            struct sk_buff **skbp)
221 {
222         struct sk_buff *skb = *skbp;
223         struct sk_buff *last;
224
225         do {
226                 struct sk_buff *frags;
227
228                 if (skb_shared(skb)) {
229                         struct sk_buff *nskb = skb_clone(skb, GFP_ATOMIC);
230
231                         if (unlikely(!nskb))
232                                 return ERR_PTR(-ENOMEM);
233
234                         nskb->next = skb->next;
235                         consume_skb(skb);
236                         skb = nskb;
237                         *skbp = skb;
238                 }
239
240                 if (head) {
241                         head->len -= skb->len;
242                         head->data_len -= skb->len;
243                         head->truesize -= skb->truesize;
244                 }
245
246                 frags = skb_shinfo(skb)->frag_list;
247                 if (frags) {
248                         int err;
249
250                         err = skb_unclone(skb, GFP_ATOMIC);
251                         if (unlikely(err))
252                                 return ERR_PTR(err);
253
254                         last = normalize_frag_list(skb, &frags);
255                         if (IS_ERR(last))
256                                 return last;
257
258                         skb_shinfo(skb)->frag_list = NULL;
259                         last->next = skb->next;
260                         skb->next = frags;
261                 } else {
262                         last = skb;
263                 }
264
265                 skbp = &skb->next;
266         } while ((skb = skb->next));
267
268         return last;
269 }
270
271 /* Takes a linked list of skbs, each of which may contain a frag_list
272  * (whose members in turn may contain frag_lists, etc.), and converts
273  * them into a single linear linked list.
274  */
275 static int straighten_frag_list(struct sk_buff **skbp)
276 {
277         struct sk_buff *err_skb;
278
279         err_skb = normalize_frag_list(NULL, skbp);
280         if (IS_ERR(err_skb))
281                 return PTR_ERR(err_skb);
282
283         return 0;
284 }
285
286 static void copy_skb_metadata(struct sk_buff *to, struct sk_buff *from)
287 {
288         to->protocol = from->protocol;
289         to->tstamp = from->tstamp;
290         to->priority = from->priority;
291         to->mark = from->mark;
292         to->vlan_tci = from->vlan_tci;
293 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
294         to->vlan_proto = from->vlan_proto;
295 #endif
296         skb_copy_secmark(to, from);
297 }
298
299 static void update_headers(struct sk_buff *skb, bool head,
300                                unsigned int l4_offset, unsigned int hdr_len,
301                                bool ipv4, u32 tcp_seq)
302 {
303         u16 old_len, new_len;
304         __be32 delta;
305         struct tcphdr *tcph;
306         int gso_size;
307
308         if (ipv4) {
309                 struct iphdr *iph = (struct iphdr *)(skb->data + ETH_HLEN);
310
311                 old_len = ntohs(iph->tot_len);
312                 new_len = skb->len - ETH_HLEN;
313                 iph->tot_len = htons(new_len);
314
315                 ip_send_check(iph);
316         } else {
317                 struct ipv6hdr *ip6h = (struct ipv6hdr *)(skb->data + ETH_HLEN);
318
319                 old_len = ntohs(ip6h->payload_len);
320                 new_len = skb->len - ETH_HLEN - sizeof(struct ipv6hdr);
321                 ip6h->payload_len = htons(new_len);
322         }
323
324         tcph = (struct tcphdr *)(skb->data + l4_offset);
325         if (!head) {
326                 tcph->seq = htonl(tcp_seq);
327                 tcph->cwr = 0;
328         }
329
330         if (skb->next) {
331                 tcph->fin = 0;
332                 tcph->psh = 0;
333         }
334
335         delta = htonl(~old_len + new_len);
336         tcph->check = ~csum_fold((__force __wsum)((__force u32)tcph->check +
337                                  (__force u32)delta));
338
339         gso_size = skb_shinfo(skb)->gso_size;
340         if (gso_size && skb->len - hdr_len <= gso_size)
341                 BUG_ON(clear_gso(skb));
342 }
343
344 static bool can_segment(struct sk_buff *head, bool ipv4, bool tcp, bool csum_partial)
345 {
346         /* If no offloading is in use then we don't have enough information
347          * to process the headers.
348          */
349         if (!csum_partial)
350                 goto linearize;
351
352         /* Handling UDP packets requires IP fragmentation, which means that
353          * the L4 checksum can no longer be calculated by hardware (since the
354          * fragments are in different packets).  If we have to compute the
355          * checksum it's faster just to linearize, and large UDP packets are
356          * pretty uncommon anyway, so it's not worth dealing with for now.
357          */
358         if (!tcp)
359                 goto linearize;
360
361         if (ipv4) {
362                 struct iphdr *iph = (struct iphdr *)(head->data + ETH_HLEN);
363
364                 /* It's difficult to get the IP IDs exactly right here due to
365                  * varying segment sizes and potentially multiple layers of
366                  * segmentation.  The IP ID isn't important when DF is set, and
367                  * DF is generally set for TCP packets, so just linearize if DF
368                  * is not set.
369                  */
370                 if (!(iph->frag_off & htons(IP_DF)))
371                         goto linearize;
372         } else {
373                 struct ipv6hdr *ip6h = (struct ipv6hdr *)(head->data + ETH_HLEN);
374
375                 /* Jumbograms require more processing to update and we'll
376                  * probably never see them, so just linearize.
377                  */
378                 if (ip6h->payload_len == 0)
379                         goto linearize;
380         }
381         return true;
382
383 linearize:
384         return false;
385 }
386
387 static int copy_headers(struct sk_buff *head, struct sk_buff *frag,
388                             int hdr_len)
389 {
390         u16 csum_start;
391
392         if (skb_cloned(frag) || skb_headroom(frag) < hdr_len) {
393                 int extra_head = hdr_len - skb_headroom(frag);
394
395                 extra_head = extra_head > 0 ? extra_head : 0;
396                 if (unlikely(pskb_expand_head(frag, extra_head, 0,
397                                               GFP_ATOMIC)))
398                         return -ENOMEM;
399         }
400
401         memcpy(__skb_push(frag, hdr_len), head->data, hdr_len);
402
403         csum_start = head->csum_start - skb_headroom(head);
404         frag->csum_start = skb_headroom(frag) + csum_start;
405         frag->csum_offset = head->csum_offset;
406         frag->ip_summed = head->ip_summed;
407
408         skb_shinfo(frag)->gso_size = skb_shinfo(head)->gso_size;
409         skb_shinfo(frag)->gso_type = skb_shinfo(head)->gso_type;
410         skb_shinfo(frag)->gso_segs = 0;
411
412         copy_skb_metadata(frag, head);
413         return 0;
414 }
415
416 static int skb_list_segment(struct sk_buff *head, bool ipv4, int l4_offset)
417 {
418         struct sk_buff *skb;
419         struct tcphdr *tcph;
420         int seg_len;
421         int hdr_len;
422         int tcp_len;
423         u32 seq;
424
425         if (unlikely(!pskb_may_pull(head, l4_offset + sizeof(*tcph))))
426                 return -ENOMEM;
427
428         tcph = (struct tcphdr *)(head->data + l4_offset);
429         tcp_len = tcph->doff * 4;
430         hdr_len = l4_offset + tcp_len;
431
432         if (unlikely((tcp_len < sizeof(struct tcphdr)) ||
433                      (head->len < hdr_len)))
434                 return -EINVAL;
435
436         if (unlikely(!pskb_may_pull(head, hdr_len)))
437                 return -ENOMEM;
438
439         tcph = (struct tcphdr *)(head->data + l4_offset);
440         /* Update header of each segment. */
441         seq = ntohl(tcph->seq);
442         seg_len = skb_pagelen(head) - hdr_len;
443
444         skb = skb_shinfo(head)->frag_list;
445         skb_shinfo(head)->frag_list = NULL;
446         head->next = skb;
447         for (; skb; skb = skb->next) {
448                 int err;
449
450                 head->len -= skb->len;
451                 head->data_len -= skb->len;
452                 head->truesize -= skb->truesize;
453
454                 seq += seg_len;
455                 seg_len = skb->len;
456                 err = copy_headers(head, skb, hdr_len);
457                 if (err)
458                         return err;
459                 update_headers(skb, false, l4_offset, hdr_len, ipv4, seq);
460         }
461         update_headers(head, true, l4_offset, hdr_len, ipv4, 0);
462         return 0;
463 }
464
465 static int coalesce_skb(struct sk_buff **headp)
466 {
467         struct sk_buff *frag, *head, *prev;
468         int err;
469
470         err = straighten_frag_list(headp);
471         if (unlikely(err))
472                 return err;
473         head = *headp;
474
475         /* Coalesce frag list. */
476         prev = head;
477         for (frag = head->next; frag; frag = frag->next) {
478                 bool headstolen;
479                 int delta;
480
481                 if (unlikely(skb_unclone(prev, GFP_ATOMIC)))
482                         return -ENOMEM;
483
484                 if (!skb_try_coalesce(prev, frag, &headstolen, &delta)) {
485                         prev = frag;
486                         continue;
487                 }
488
489                 prev->next = frag->next;
490                 frag->len = 0;
491                 frag->data_len = 0;
492                 frag->truesize -= delta;
493                 kfree_skb_partial(frag, headstolen);
494                 frag = prev;
495         }
496
497         if (!head->next)
498                 return 0;
499
500         for (frag = head->next; frag; frag = frag->next) {
501                 head->len += frag->len;
502                 head->data_len += frag->len;
503                 head->truesize += frag->truesize;
504         }
505
506         skb_shinfo(head)->frag_list = head->next;
507         head->next = NULL;
508         return 0;
509 }
510
511 static int __try_to_segment(struct sk_buff *skb, bool csum_partial,
512                             bool ipv4, bool tcp, int l4_offset)
513 {
514         if (can_segment(skb, ipv4, tcp, csum_partial))
515                 return skb_list_segment(skb, ipv4, l4_offset);
516         else
517                 return skb_linearize(skb);
518 }
519
520 static int try_to_segment(struct sk_buff *skb)
521 {
522         struct stthdr *stth = stt_hdr(skb);
523         bool csum_partial = !!(stth->flags & STT_CSUM_PARTIAL);
524         bool ipv4 = !!(stth->flags & STT_PROTO_IPV4);
525         bool tcp = !!(stth->flags & STT_PROTO_TCP);
526         int l4_offset = stth->l4_offset;
527
528         return __try_to_segment(skb, csum_partial, ipv4, tcp, l4_offset);
529 }
530
531 static int segment_skb(struct sk_buff **headp, bool csum_partial,
532                        bool ipv4, bool tcp, int l4_offset)
533 {
534         int err;
535
536         err = coalesce_skb(headp);
537         if (err)
538                 return err;
539
540         if (skb_shinfo(*headp)->frag_list)
541                 return __try_to_segment(*headp, csum_partial,
542                                         ipv4, tcp, l4_offset);
543         return 0;
544 }
545
546 static int __push_stt_header(struct sk_buff *skb, __be64 tun_id,
547                              __be16 s_port, __be16 d_port,
548                              __be32 saddr, __be32 dst,
549                              __be16 l3_proto, u8 l4_proto,
550                              int dst_mtu)
551 {
552         int data_len = skb->len + sizeof(struct stthdr) + STT_ETH_PAD;
553         unsigned short encap_mss;
554         struct tcphdr *tcph;
555         struct stthdr *stth;
556
557         skb_push(skb, STT_HEADER_LEN);
558         skb_reset_transport_header(skb);
559         tcph = tcp_hdr(skb);
560         memset(tcph, 0, STT_HEADER_LEN);
561         stth = stt_hdr(skb);
562
563         if (skb->ip_summed == CHECKSUM_PARTIAL) {
564                 stth->flags |= STT_CSUM_PARTIAL;
565
566                 stth->l4_offset = skb->csum_start -
567                                         (skb_headroom(skb) +
568                                         STT_HEADER_LEN);
569
570                 if (l3_proto == htons(ETH_P_IP))
571                         stth->flags |= STT_PROTO_IPV4;
572
573                 if (l4_proto == IPPROTO_TCP)
574                         stth->flags |= STT_PROTO_TCP;
575
576                 stth->mss = htons(skb_shinfo(skb)->gso_size);
577         } else if (skb->ip_summed == CHECKSUM_UNNECESSARY) {
578                 stth->flags |= STT_CSUM_VERIFIED;
579         }
580
581         stth->vlan_tci = htons(skb->vlan_tci);
582         skb->vlan_tci = 0;
583         put_unaligned(tun_id, &stth->key);
584
585         tcph->source    = s_port;
586         tcph->dest      = d_port;
587         tcph->doff      = sizeof(struct tcphdr) / 4;
588         tcph->ack       = 1;
589         tcph->psh       = 1;
590         tcph->window    = htons(USHRT_MAX);
591         tcph->seq       = htonl(data_len << STT_SEQ_LEN_SHIFT);
592         tcph->ack_seq   = ack_seq();
593         tcph->check     = ~tcp_v4_check(skb->len, saddr, dst, 0);
594
595         skb->csum_start = skb_transport_header(skb) - skb->head;
596         skb->csum_offset = offsetof(struct tcphdr, check);
597         skb->ip_summed = CHECKSUM_PARTIAL;
598
599         encap_mss = dst_mtu - sizeof(struct iphdr) - sizeof(struct tcphdr);
600         if (data_len > encap_mss) {
601                 if (unlikely(skb_unclone(skb, GFP_ATOMIC)))
602                         return -EINVAL;
603
604                 skb_shinfo(skb)->gso_type = SKB_GSO_TCPV4;
605                 skb_shinfo(skb)->gso_size = encap_mss;
606                 skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(data_len, encap_mss);
607         } else {
608                 if (unlikely(clear_gso(skb)))
609                         return -EINVAL;
610         }
611         return 0;
612 }
613
614 static struct sk_buff *push_stt_header(struct sk_buff *head, __be64 tun_id,
615                                        __be16 s_port, __be16 d_port,
616                                        __be32 saddr, __be32 dst,
617                                        __be16 l3_proto, u8 l4_proto,
618                                        int dst_mtu)
619 {
620         struct sk_buff *skb;
621
622         if (skb_shinfo(head)->frag_list) {
623                 bool ipv4 = (l3_proto == htons(ETH_P_IP));
624                 bool tcp = (l4_proto == IPPROTO_TCP);
625                 bool csum_partial = (head->ip_summed == CHECKSUM_PARTIAL);
626                 int l4_offset = skb_transport_offset(head);
627
628                 /* Need to call skb_orphan() to report the correct truesize.
629                  * Calling skb_orphan() in this layer is odd, but an SKB with
630                  * a frag-list should not be associated with any socket, so
631                  * skb_orphan() should be a no-op. */
632                 skb_orphan(head);
633                 if (unlikely(segment_skb(&head, csum_partial,
634                                          ipv4, tcp, l4_offset)))
635                         goto error;
636         }
637
638         for (skb = head; skb; skb = skb->next) {
639                 if (__push_stt_header(skb, tun_id, s_port, d_port, saddr, dst,
640                                       l3_proto, l4_proto, dst_mtu))
641                         goto error;
642         }
643
644         return head;
645 error:
646         kfree_skb_list(head);
647         return NULL;
648 }
649
650 static int stt_can_offload(struct sk_buff *skb, __be16 l3_proto, u8 l4_proto)
651 {
652         if (skb_is_gso(skb) && skb->ip_summed != CHECKSUM_PARTIAL) {
653                 int csum_offset;
654                 __sum16 *csum;
655                 int len;
656
657                 if (l4_proto == IPPROTO_TCP)
658                         csum_offset = offsetof(struct tcphdr, check);
659                 else if (l4_proto == IPPROTO_UDP)
660                         csum_offset = offsetof(struct udphdr, check);
661                 else
662                         return 0;
663
664                 len = skb->len - skb_transport_offset(skb);
665                 csum = (__sum16 *)(skb_transport_header(skb) + csum_offset);
666
667                 if (unlikely(!pskb_may_pull(skb, skb_transport_offset(skb) +
668                                                  csum_offset + sizeof(*csum))))
669                         return -EINVAL;
670
671                 if (l3_proto == htons(ETH_P_IP)) {
672                         struct iphdr *iph = ip_hdr(skb);
673
674                         *csum = ~csum_tcpudp_magic(iph->saddr, iph->daddr,
675                                                    len, l4_proto, 0);
676                 } else if (l3_proto == htons(ETH_P_IPV6)) {
677                         struct ipv6hdr *ip6h = ipv6_hdr(skb);
678
679                         *csum = ~csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
680                                                  len, l4_proto, 0);
681                 } else {
682                         return 0;
683                 }
684                 skb->csum_start = skb_transport_header(skb) - skb->head;
685                 skb->csum_offset = csum_offset;
686                 skb->ip_summed = CHECKSUM_PARTIAL;
687         }
688
689         if (skb->ip_summed == CHECKSUM_PARTIAL) {
690                 /* Assume receiver can only offload TCP/UDP over IPv4/6,
691                  * and require 802.1Q VLANs to be accelerated.
692                  */
693                 if (l3_proto != htons(ETH_P_IP) &&
694                     l3_proto != htons(ETH_P_IPV6))
695                         return 0;
696
697                 if (l4_proto != IPPROTO_TCP && l4_proto != IPPROTO_UDP)
698                         return 0;
699
700                 /* L4 offset must fit in a 1-byte field. */
701                 if (skb->csum_start - skb_headroom(skb) > 255)
702                         return 0;
703
704                 if (skb_shinfo(skb)->gso_type & ~SUPPORTED_GSO_TYPES)
705                         return 0;
706         }
707         /* Total size of encapsulated packet must fit in 16 bits. */
708         if (skb->len + STT_HEADER_LEN + sizeof(struct iphdr) > 65535)
709                 return 0;
710
711 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
712         if (skb_vlan_tag_present(skb) && skb->vlan_proto != htons(ETH_P_8021Q))
713                 return 0;
714 #endif
715         return 1;
716 }
717
718 static bool need_linearize(const struct sk_buff *skb)
719 {
720         struct skb_shared_info *shinfo = skb_shinfo(skb);
721         int i;
722
723         if (unlikely(shinfo->frag_list))
724                 return true;
725
726         /* Generally speaking we should linearize if there are paged frags.
727          * However, if all of the refcounts are 1 we know nobody else can
728          * change them from underneath us and we can skip the linearization.
729          */
730         for (i = 0; i < shinfo->nr_frags; i++)
731                 if (unlikely(page_count(skb_frag_page(&shinfo->frags[i])) > 1))
732                         return true;
733
734         return false;
735 }
736
737 static struct sk_buff *handle_offloads(struct sk_buff *skb, int min_headroom)
738 {
739         int err;
740
741 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
742         if (skb_vlan_tag_present(skb) && skb->vlan_proto != htons(ETH_P_8021Q)) {
743
744                 min_headroom += VLAN_HLEN;
745                 if (skb_headroom(skb) < min_headroom) {
746                         int head_delta = SKB_DATA_ALIGN(min_headroom -
747                                                         skb_headroom(skb) + 16);
748
749                         err = pskb_expand_head(skb, max_t(int, head_delta, 0),
750                                                0, GFP_ATOMIC);
751                         if (unlikely(err))
752                                 goto error;
753                 }
754
755                 skb = __vlan_hwaccel_push_inside(skb);
756                 if (!skb) {
757                         err = -ENOMEM;
758                         goto error;
759                 }
760         }
761 #endif
762
763         if (skb_is_gso(skb)) {
764                 struct sk_buff *nskb;
765                 char cb[sizeof(skb->cb)];
766
767                 memcpy(cb, skb->cb, sizeof(cb));
768
769                 nskb = __skb_gso_segment(skb, 0, false);
770                 if (IS_ERR(nskb)) {
771                         err = PTR_ERR(nskb);
772                         goto error;
773                 }
774
775                 consume_skb(skb);
776                 skb = nskb;
777                 while (nskb) {
778                         memcpy(nskb->cb, cb, sizeof(cb));
779                         nskb = nskb->next;
780                 }
781         } else if (skb->ip_summed == CHECKSUM_PARTIAL) {
782                 /* Pages aren't locked and could change at any time.
783                  * If this happens after we compute the checksum, the
784                  * checksum will be wrong.  We linearize now to avoid
785                  * this problem.
786                  */
787                 if (unlikely(need_linearize(skb))) {
788                         err = __skb_linearize(skb);
789                         if (unlikely(err))
790                                 goto error;
791                 }
792
793                 err = skb_checksum_help(skb);
794                 if (unlikely(err))
795                         goto error;
796         }
797         skb->ip_summed = CHECKSUM_NONE;
798
799         return skb;
800 error:
801         kfree_skb(skb);
802         return ERR_PTR(err);
803 }
804
805 static int skb_list_xmit(struct rtable *rt, struct sk_buff *skb, __be32 src,
806                          __be32 dst, __u8 tos, __u8 ttl, __be16 df)
807 {
808         int len = 0;
809
810         while (skb) {
811                 struct sk_buff *next = skb->next;
812
813                 if (next)
814                         dst_clone(&rt->dst);
815
816                 skb->next = NULL;
817                 len += iptunnel_xmit(NULL, rt, skb, src, dst, IPPROTO_TCP,
818                                      tos, ttl, df, false);
819
820                 skb = next;
821         }
822         return len;
823 }
824
825 static u8 parse_ipv6_l4_proto(struct sk_buff *skb)
826 {
827         unsigned int nh_ofs = skb_network_offset(skb);
828         int payload_ofs;
829         struct ipv6hdr *nh;
830         uint8_t nexthdr;
831         __be16 frag_off;
832
833         if (unlikely(!pskb_may_pull(skb, nh_ofs + sizeof(struct ipv6hdr))))
834                 return 0;
835
836         nh = ipv6_hdr(skb);
837         nexthdr = nh->nexthdr;
838         payload_ofs = (u8 *)(nh + 1) - skb->data;
839
840         payload_ofs = ipv6_skip_exthdr(skb, payload_ofs, &nexthdr, &frag_off);
841         if (unlikely(payload_ofs < 0))
842                 return 0;
843
844         return nexthdr;
845 }
846
847 static u8 skb_get_l4_proto(struct sk_buff *skb, __be16 l3_proto)
848 {
849         if (l3_proto == htons(ETH_P_IP)) {
850                 unsigned int nh_ofs = skb_network_offset(skb);
851
852                 if (unlikely(!pskb_may_pull(skb, nh_ofs + sizeof(struct iphdr))))
853                         return 0;
854
855                 return ip_hdr(skb)->protocol;
856         } else if (l3_proto == htons(ETH_P_IPV6)) {
857                 return parse_ipv6_l4_proto(skb);
858         }
859         return 0;
860 }
861
862 static int stt_xmit_skb(struct sk_buff *skb, struct rtable *rt,
863                  __be32 src, __be32 dst, __u8 tos,
864                  __u8 ttl, __be16 df, __be16 src_port, __be16 dst_port,
865                  __be64 tun_id)
866 {
867         struct ethhdr *eh = eth_hdr(skb);
868         int ret = 0, min_headroom;
869         __be16 inner_l3_proto;
870          u8 inner_l4_proto;
871
872         inner_l3_proto = eh->h_proto;
873         inner_l4_proto = skb_get_l4_proto(skb, inner_l3_proto);
874
875         min_headroom = LL_RESERVED_SPACE(rt->dst.dev) + rt->dst.header_len
876                         + STT_HEADER_LEN + sizeof(struct iphdr);
877
878         if (skb_headroom(skb) < min_headroom || skb_header_cloned(skb)) {
879                 int head_delta = SKB_DATA_ALIGN(min_headroom -
880                                                 skb_headroom(skb) +
881                                                 16);
882
883                 ret = pskb_expand_head(skb, max_t(int, head_delta, 0),
884                                        0, GFP_ATOMIC);
885                 if (unlikely(ret))
886                         goto err_free_rt;
887         }
888
889         ret = stt_can_offload(skb, inner_l3_proto, inner_l4_proto);
890         if (ret < 0)
891                 goto err_free_rt;
892         if (!ret) {
893                 skb = handle_offloads(skb, min_headroom);
894                 if (IS_ERR(skb)) {
895                         ret = PTR_ERR(skb);
896                         skb = NULL;
897                         goto err_free_rt;
898                 }
899         }
900
901         ret = 0;
902         while (skb) {
903                 struct sk_buff *next_skb = skb->next;
904
905                 skb->next = NULL;
906
907                 if (next_skb)
908                         dst_clone(&rt->dst);
909
910                 /* Push STT and TCP header. */
911                 skb = push_stt_header(skb, tun_id, src_port, dst_port, src,
912                                       dst, inner_l3_proto, inner_l4_proto,
913                                       dst_mtu(&rt->dst));
914                 if (unlikely(!skb)) {
915                         ip_rt_put(rt);
916                         goto next;
917                 }
918
919                 /* Push IP header. */
920                 ret += skb_list_xmit(rt, skb, src, dst, tos, ttl, df);
921
922 next:
923                 skb = next_skb;
924         }
925
926         return ret;
927
928 err_free_rt:
929         ip_rt_put(rt);
930         kfree_skb(skb);
931         return ret;
932 }
933
934 netdev_tx_t ovs_stt_xmit(struct sk_buff *skb)
935 {
936         struct net_device *dev = skb->dev;
937         struct stt_dev *stt_dev = netdev_priv(dev);
938         struct net *net = stt_dev->net;
939         __be16 dport = stt_dev->dst_port;
940         struct ip_tunnel_key *tun_key;
941         struct ip_tunnel_info *tun_info;
942         struct rtable *rt;
943         struct flowi4 fl;
944         __be16 sport;
945         __be16 df;
946         int err;
947
948         tun_info = skb_tunnel_info(skb);
949         if (unlikely(!tun_info)) {
950                 err = -EINVAL;
951                 goto error;
952         }
953
954         tun_key = &tun_info->key;
955
956         /* Route lookup */
957         memset(&fl, 0, sizeof(fl));
958         fl.daddr = tun_key->u.ipv4.dst;
959         fl.saddr = tun_key->u.ipv4.src;
960         fl.flowi4_tos = RT_TOS(tun_key->tos);
961         fl.flowi4_mark = skb->mark;
962         fl.flowi4_proto = IPPROTO_TCP;
963         rt = ip_route_output_key(net, &fl);
964         if (IS_ERR(rt)) {
965                 err = PTR_ERR(rt);
966                 goto error;
967         }
968
969         df = tun_key->tun_flags & TUNNEL_DONT_FRAGMENT ? htons(IP_DF) : 0;
970         sport = udp_flow_src_port(net, skb, 1, USHRT_MAX, true);
971         skb->ignore_df = 1;
972
973         err = stt_xmit_skb(skb, rt, fl.saddr, tun_key->u.ipv4.dst,
974                             tun_key->tos, tun_key->ttl,
975                             df, sport, dport, tun_key->tun_id);
976         iptunnel_xmit_stats(err, &dev->stats, (struct pcpu_sw_netstats __percpu *)dev->tstats);
977         return NETDEV_TX_OK;
978 error:
979         kfree_skb(skb);
980         dev->stats.tx_errors++;
981         return NETDEV_TX_OK;
982 }
983 EXPORT_SYMBOL(ovs_stt_xmit);
984
985 static void free_frag(struct stt_percpu *stt_percpu,
986                       struct pkt_frag *frag)
987 {
988         stt_percpu->frag_mem_used -= FRAG_CB(frag->skbs)->first.mem_used;
989         kfree_skb_list(frag->skbs);
990         list_del(&frag->lru_node);
991         frag->skbs = NULL;
992 }
993
994 static void evict_frags(struct stt_percpu *stt_percpu)
995 {
996         while (!list_empty(&stt_percpu->frag_lru) &&
997                stt_percpu->frag_mem_used > REASM_LO_THRESH) {
998                 struct pkt_frag *frag;
999
1000                 frag = list_first_entry(&stt_percpu->frag_lru,
1001                                         struct pkt_frag,
1002                                         lru_node);
1003                 free_frag(stt_percpu, frag);
1004         }
1005 }
1006
1007 static bool pkt_key_match(struct net *net,
1008                           const struct pkt_frag *a, const struct pkt_key *b)
1009 {
1010         return a->key.saddr == b->saddr && a->key.daddr == b->daddr &&
1011                a->key.pkt_seq == b->pkt_seq && a->key.mark == b->mark &&
1012                net_eq(dev_net(a->skbs->dev), net);
1013 }
1014
1015 static u32 pkt_key_hash(const struct net *net, const struct pkt_key *key)
1016 {
1017         u32 initval = frag_hash_seed ^ (u32)(unsigned long)net ^ key->mark;
1018
1019         return jhash_3words((__force u32)key->saddr, (__force u32)key->daddr,
1020                             (__force u32)key->pkt_seq, initval);
1021 }
1022
1023 static struct pkt_frag *lookup_frag(struct net *net,
1024                                     struct stt_percpu *stt_percpu,
1025                                     const struct pkt_key *key, u32 hash)
1026 {
1027         struct pkt_frag *frag, *victim_frag = NULL;
1028         int i;
1029
1030         for (i = 0; i < FRAG_HASH_SEGS; i++) {
1031                 frag = flex_array_get(stt_percpu->frag_hash,
1032                                       hash & (FRAG_HASH_ENTRIES - 1));
1033
1034                 if (frag->skbs &&
1035                     time_before(jiffies, frag->timestamp + FRAG_EXP_TIME) &&
1036                     pkt_key_match(net, frag, key))
1037                         return frag;
1038
1039                 if (!victim_frag ||
1040                     (victim_frag->skbs &&
1041                      (!frag->skbs ||
1042                       time_before(frag->timestamp, victim_frag->timestamp))))
1043                         victim_frag = frag;
1044
1045                 hash >>= FRAG_HASH_SHIFT;
1046         }
1047
1048         if (victim_frag->skbs)
1049                 free_frag(stt_percpu, victim_frag);
1050
1051         return victim_frag;
1052 }
1053
1054 static struct sk_buff *reassemble(struct sk_buff *skb)
1055 {
1056         struct iphdr *iph = ip_hdr(skb);
1057         struct tcphdr *tcph = tcp_hdr(skb);
1058         u32 seq = ntohl(tcph->seq);
1059         struct stt_percpu *stt_percpu;
1060         struct sk_buff *last_skb;
1061         struct pkt_frag *frag;
1062         struct pkt_key key;
1063         int tot_len;
1064         u32 hash;
1065
1066         tot_len = seq >> STT_SEQ_LEN_SHIFT;
1067         FRAG_CB(skb)->offset = seq & STT_SEQ_OFFSET_MASK;
1068
1069         if (unlikely(skb->len == 0))
1070                 goto out_free;
1071
1072         if (unlikely(FRAG_CB(skb)->offset + skb->len > tot_len))
1073                 goto out_free;
1074
1075         if (tot_len == skb->len)
1076                 goto out;
1077
1078         key.saddr = iph->saddr;
1079         key.daddr = iph->daddr;
1080         key.pkt_seq = tcph->ack_seq;
1081         key.mark = skb->mark;
1082         hash = pkt_key_hash(dev_net(skb->dev), &key);
1083
1084         stt_percpu = per_cpu_ptr(stt_percpu_data, smp_processor_id());
1085
1086         spin_lock(&stt_percpu->lock);
1087
1088         if (unlikely(stt_percpu->frag_mem_used + skb->truesize > REASM_HI_THRESH))
1089                 evict_frags(stt_percpu);
1090
1091         frag = lookup_frag(dev_net(skb->dev), stt_percpu, &key, hash);
1092         if (!frag->skbs) {
1093                 frag->skbs = skb;
1094                 frag->key = key;
1095                 frag->timestamp = jiffies;
1096                 FRAG_CB(skb)->first.last_skb = skb;
1097                 FRAG_CB(skb)->first.mem_used = skb->truesize;
1098                 FRAG_CB(skb)->first.tot_len = tot_len;
1099                 FRAG_CB(skb)->first.rcvd_len = skb->len;
1100                 FRAG_CB(skb)->first.set_ecn_ce = false;
1101                 list_add_tail(&frag->lru_node, &stt_percpu->frag_lru);
1102                 stt_percpu->frag_mem_used += skb->truesize;
1103
1104                 skb = NULL;
1105                 goto unlock;
1106         }
1107
1108         /* Optimize for the common case where fragments are received in-order
1109          * and not overlapping.
1110          */
1111         last_skb = FRAG_CB(frag->skbs)->first.last_skb;
1112         if (likely(FRAG_CB(last_skb)->offset + last_skb->len ==
1113                    FRAG_CB(skb)->offset)) {
1114                 last_skb->next = skb;
1115                 FRAG_CB(frag->skbs)->first.last_skb = skb;
1116         } else {
1117                 struct sk_buff *prev = NULL, *next;
1118
1119                 for (next = frag->skbs; next; next = next->next) {
1120                         if (FRAG_CB(next)->offset >= FRAG_CB(skb)->offset)
1121                                 break;
1122                         prev = next;
1123                 }
1124
1125                 /* Overlapping fragments aren't allowed.  We shouldn't start
1126                  * before the end of the previous fragment.
1127                  */
1128                 if (prev &&
1129                     FRAG_CB(prev)->offset + prev->len > FRAG_CB(skb)->offset)
1130                         goto unlock_free;
1131
1132                 /* We also shouldn't end after the beginning of the next
1133                  * fragment.
1134                  */
1135                 if (next &&
1136                     FRAG_CB(skb)->offset + skb->len > FRAG_CB(next)->offset)
1137                         goto unlock_free;
1138
1139                 if (prev) {
1140                         prev->next = skb;
1141                 } else {
1142                         FRAG_CB(skb)->first = FRAG_CB(frag->skbs)->first;
1143                         frag->skbs = skb;
1144                 }
1145
1146                 if (next)
1147                         skb->next = next;
1148                 else
1149                         FRAG_CB(frag->skbs)->first.last_skb = skb;
1150         }
1151
1152         FRAG_CB(frag->skbs)->first.set_ecn_ce |= INET_ECN_is_ce(iph->tos);
1153         FRAG_CB(frag->skbs)->first.rcvd_len += skb->len;
1154         FRAG_CB(frag->skbs)->first.mem_used += skb->truesize;
1155         stt_percpu->frag_mem_used += skb->truesize;
1156
1157         if (FRAG_CB(frag->skbs)->first.tot_len ==
1158             FRAG_CB(frag->skbs)->first.rcvd_len) {
1159                 struct sk_buff *frag_head = frag->skbs;
1160
1161                 frag_head->tstamp = skb->tstamp;
1162                 if (FRAG_CB(frag_head)->first.set_ecn_ce)
1163                         INET_ECN_set_ce(frag_head);
1164
1165                 list_del(&frag->lru_node);
1166                 stt_percpu->frag_mem_used -= FRAG_CB(frag_head)->first.mem_used;
1167                 frag->skbs = NULL;
1168                 skb = frag_head;
1169         } else {
1170                 list_move_tail(&frag->lru_node, &stt_percpu->frag_lru);
1171                 skb = NULL;
1172         }
1173
1174         goto unlock;
1175
1176 unlock_free:
1177         kfree_skb(skb);
1178         skb = NULL;
1179 unlock:
1180         spin_unlock(&stt_percpu->lock);
1181         return skb;
1182 out_free:
1183         kfree_skb(skb);
1184         skb = NULL;
1185 out:
1186         return skb;
1187 }
1188
1189 static bool validate_checksum(struct sk_buff *skb)
1190 {
1191         struct iphdr *iph = ip_hdr(skb);
1192
1193         if (skb_csum_unnecessary(skb))
1194                 return true;
1195
1196         if (skb->ip_summed == CHECKSUM_COMPLETE &&
1197             !tcp_v4_check(skb->len, iph->saddr, iph->daddr, skb->csum))
1198                 return true;
1199
1200         skb->csum = csum_tcpudp_nofold(iph->saddr, iph->daddr, skb->len,
1201                                        IPPROTO_TCP, 0);
1202
1203         return __tcp_checksum_complete(skb) == 0;
1204 }
1205
1206 static bool set_offloads(struct sk_buff *skb)
1207 {
1208         struct stthdr *stth = stt_hdr(skb);
1209         unsigned short gso_type;
1210         int l3_header_size;
1211         int l4_header_size;
1212         u16 csum_offset;
1213         u8 proto_type;
1214
1215         if (stth->vlan_tci)
1216                 __vlan_hwaccel_put_tag(skb, htons(ETH_P_8021Q),
1217                                        ntohs(stth->vlan_tci));
1218
1219         if (!(stth->flags & STT_CSUM_PARTIAL)) {
1220                 if (stth->flags & STT_CSUM_VERIFIED)
1221                         skb->ip_summed = CHECKSUM_UNNECESSARY;
1222                 else
1223                         skb->ip_summed = CHECKSUM_NONE;
1224
1225                 return clear_gso(skb) == 0;
1226         }
1227
1228         proto_type = stth->flags & STT_PROTO_TYPES;
1229
1230         switch (proto_type) {
1231         case (STT_PROTO_IPV4 | STT_PROTO_TCP):
1232                 /* TCP/IPv4 */
1233                 csum_offset = offsetof(struct tcphdr, check);
1234                 gso_type = SKB_GSO_TCPV4;
1235                 l3_header_size = sizeof(struct iphdr);
1236                 l4_header_size = sizeof(struct tcphdr);
1237                 skb->protocol = htons(ETH_P_IP);
1238                 break;
1239         case STT_PROTO_TCP:
1240                 /* TCP/IPv6 */
1241                 csum_offset = offsetof(struct tcphdr, check);
1242                 gso_type = SKB_GSO_TCPV6;
1243                 l3_header_size = sizeof(struct ipv6hdr);
1244                 l4_header_size = sizeof(struct tcphdr);
1245                 skb->protocol = htons(ETH_P_IPV6);
1246                 break;
1247         case STT_PROTO_IPV4:
1248                 /* UDP/IPv4 */
1249                 csum_offset = offsetof(struct udphdr, check);
1250                 gso_type = SKB_GSO_UDP;
1251                 l3_header_size = sizeof(struct iphdr);
1252                 l4_header_size = sizeof(struct udphdr);
1253                 skb->protocol = htons(ETH_P_IP);
1254                 break;
1255         default:
1256                 /* UDP/IPv6 */
1257                 csum_offset = offsetof(struct udphdr, check);
1258                 gso_type = SKB_GSO_UDP;
1259                 l3_header_size = sizeof(struct ipv6hdr);
1260                 l4_header_size = sizeof(struct udphdr);
1261                 skb->protocol = htons(ETH_P_IPV6);
1262         }
1263
1264         if (unlikely(stth->l4_offset < ETH_HLEN + l3_header_size))
1265                 return false;
1266
1267         if (unlikely(!pskb_may_pull(skb, stth->l4_offset + l4_header_size)))
1268                 return false;
1269
1270         stth = stt_hdr(skb);
1271
1272         skb->csum_start = skb_headroom(skb) + stth->l4_offset;
1273         skb->csum_offset = csum_offset;
1274         skb->ip_summed = CHECKSUM_PARTIAL;
1275
1276         if (stth->mss) {
1277                 if (unlikely(skb_unclone(skb, GFP_ATOMIC)))
1278                         return false;
1279
1280                 skb_shinfo(skb)->gso_type = gso_type | SKB_GSO_DODGY;
1281                 skb_shinfo(skb)->gso_size = ntohs(stth->mss);
1282                 skb_shinfo(skb)->gso_segs = 0;
1283         } else {
1284                 if (unlikely(clear_gso(skb)))
1285                         return false;
1286         }
1287
1288         return true;
1289 }
1290
1291 static void rcv_list(struct net_device *dev, struct sk_buff *skb,
1292                      struct metadata_dst *tun_dst)
1293 {
1294         struct sk_buff *next;
1295
1296         do {
1297                 next = skb->next;
1298                 skb->next = NULL;
1299                 if (next) {
1300                         ovs_dst_hold((struct dst_entry *)tun_dst);
1301                         ovs_skb_dst_set(next, (struct dst_entry *)tun_dst);
1302                 }
1303                 ovs_ip_tunnel_rcv(dev, skb, tun_dst);
1304         } while ((skb = next));
1305 }
1306
1307 #ifndef HAVE_METADATA_DST
1308 static int __stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
1309 {
1310         struct metadata_dst tun_dst;
1311
1312         ovs_ip_tun_rx_dst(&tun_dst.u.tun_info, skb, TUNNEL_KEY | TUNNEL_CSUM,
1313                           get_unaligned(&stt_hdr(skb)->key), 0);
1314         tun_dst.u.tun_info.key.tp_src = tcp_hdr(skb)->source;
1315         tun_dst.u.tun_info.key.tp_dst = tcp_hdr(skb)->dest;
1316
1317         rcv_list(stt_dev->dev, skb, &tun_dst);
1318         return 0;
1319 }
1320 #else
1321 static int __stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
1322 {
1323         struct metadata_dst *tun_dst;
1324         __be16 flags;
1325         __be64 tun_id;
1326
1327         flags = TUNNEL_KEY | TUNNEL_CSUM;
1328         tun_id = get_unaligned(&stt_hdr(skb)->key);
1329         tun_dst = ip_tun_rx_dst(skb, flags, tun_id, 0);
1330         if (!tun_dst)
1331                 return -ENOMEM;
1332         tun_dst->u.tun_info.key.tp_src = tcp_hdr(skb)->source;
1333         tun_dst->u.tun_info.key.tp_dst = tcp_hdr(skb)->dest;
1334
1335         rcv_list(stt_dev->dev, skb, tun_dst);
1336         return 0;
1337 }
1338 #endif
1339
1340 static void stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
1341 {
1342         int err;
1343
1344         if (unlikely(!validate_checksum(skb)))
1345                 goto drop;
1346
1347         skb = reassemble(skb);
1348         if (!skb)
1349                 return;
1350
1351         if (skb->next && coalesce_skb(&skb))
1352                 goto drop;
1353
1354         err = iptunnel_pull_header(skb,
1355                                    sizeof(struct stthdr) + STT_ETH_PAD,
1356                                    htons(ETH_P_TEB));
1357         if (unlikely(err))
1358                 goto drop;
1359
1360         if (unlikely(stt_hdr(skb)->version != 0))
1361                 goto drop;
1362
1363         if (unlikely(!set_offloads(skb)))
1364                 goto drop;
1365
1366         if (skb_shinfo(skb)->frag_list && try_to_segment(skb))
1367                 goto drop;
1368
1369         err = __stt_rcv(stt_dev, skb);
1370         if (err)
1371                 goto drop;
1372         return;
1373 drop:
1374         /* Consume bad packet */
1375         kfree_skb_list(skb);
1376         stt_dev->dev->stats.rx_errors++;
1377 }
1378
1379 static void tcp_sock_release(struct socket *sock)
1380 {
1381         kernel_sock_shutdown(sock, SHUT_RDWR);
1382         sock_release(sock);
1383 }
1384
1385 static int tcp_sock_create4(struct net *net, __be16 port,
1386                             struct socket **sockp)
1387 {
1388         struct sockaddr_in tcp_addr;
1389         struct socket *sock = NULL;
1390         int err;
1391
1392         err = sock_create_kern(net, AF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
1393         if (err < 0)
1394                 goto error;
1395
1396         memset(&tcp_addr, 0, sizeof(tcp_addr));
1397         tcp_addr.sin_family = AF_INET;
1398         tcp_addr.sin_addr.s_addr = htonl(INADDR_ANY);
1399         tcp_addr.sin_port = port;
1400         err = kernel_bind(sock, (struct sockaddr *)&tcp_addr,
1401                           sizeof(tcp_addr));
1402         if (err < 0)
1403                 goto error;
1404
1405         *sockp = sock;
1406         return 0;
1407
1408 error:
1409         if (sock)
1410                 tcp_sock_release(sock);
1411         *sockp = NULL;
1412         return err;
1413 }
1414
1415 static void schedule_clean_percpu(void)
1416 {
1417         schedule_delayed_work(&clean_percpu_wq, CLEAN_PERCPU_INTERVAL);
1418 }
1419
1420 static void clean_percpu(struct work_struct *work)
1421 {
1422         int i;
1423
1424         for_each_possible_cpu(i) {
1425                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1426                 int j;
1427
1428                 for (j = 0; j < FRAG_HASH_ENTRIES; j++) {
1429                         struct pkt_frag *frag;
1430
1431                         frag = flex_array_get(stt_percpu->frag_hash, j);
1432                         if (!frag->skbs ||
1433                             time_before(jiffies, frag->timestamp + FRAG_EXP_TIME))
1434                                 continue;
1435
1436                         spin_lock_bh(&stt_percpu->lock);
1437
1438                         if (frag->skbs &&
1439                             time_after(jiffies, frag->timestamp + FRAG_EXP_TIME))
1440                                 free_frag(stt_percpu, frag);
1441
1442                         spin_unlock_bh(&stt_percpu->lock);
1443                 }
1444         }
1445         schedule_clean_percpu();
1446 }
1447
1448 #ifdef HAVE_NF_HOOKFN_ARG_OPS
1449 #define FIRST_PARAM const struct nf_hook_ops *ops
1450 #else
1451 #define FIRST_PARAM unsigned int hooknum
1452 #endif
1453
1454 #ifdef HAVE_NF_HOOK_STATE
1455 #if RHEL_RELEASE_CODE > RHEL_RELEASE_VERSION(7,0)
1456 /* RHEL nfhook hacks. */
1457 #ifndef __GENKSYMS__
1458 #define LAST_PARAM const struct net_device *in, const struct net_device *out, \
1459                    const struct nf_hook_state *state
1460 #else
1461 #define LAST_PARAM const struct net_device *in, const struct net_device *out, \
1462                    int (*okfn)(struct sk_buff *)
1463 #endif
1464 #else
1465 #define LAST_PARAM const struct nf_hook_state *state
1466 #endif
1467 #else
1468 #define LAST_PARAM const struct net_device *in, const struct net_device *out, \
1469                    int (*okfn)(struct sk_buff *)
1470 #endif
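
/* Editorial note (not part of the original file): assuming a non-RHEL kernel
 * where HAVE_NF_HOOKFN_ARG_OPS and HAVE_NF_HOOK_STATE are both defined, the
 * macros above make the hook below expand to roughly:
 *
 *	static unsigned int nf_ip_hook(const struct nf_hook_ops *ops,
 *				       struct sk_buff *skb,
 *				       const struct nf_hook_state *state);
 */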
1471
1472 static unsigned int nf_ip_hook(FIRST_PARAM, struct sk_buff *skb, LAST_PARAM)
1473 {
1474         struct stt_dev *stt_dev;
1475         int ip_hdr_len;
1476
1477         if (ip_hdr(skb)->protocol != IPPROTO_TCP)
1478                 return NF_ACCEPT;
1479
1480         ip_hdr_len = ip_hdrlen(skb);
1481         if (unlikely(!pskb_may_pull(skb, ip_hdr_len + sizeof(struct tcphdr))))
1482                 return NF_ACCEPT;
1483
1484         skb_set_transport_header(skb, ip_hdr_len);
1485
1486         stt_dev = stt_find_up_dev(dev_net(skb->dev), tcp_hdr(skb)->dest);
1487         if (!stt_dev)
1488                 return NF_ACCEPT;
1489
1490         __skb_pull(skb, ip_hdr_len + sizeof(struct tcphdr));
1491         stt_rcv(stt_dev, skb);
1492         return NF_STOLEN;
1493 }
1494
1495 static struct nf_hook_ops nf_hook_ops __read_mostly = {
1496         .hook           = nf_ip_hook,
1497         .owner          = THIS_MODULE,
1498         .pf             = NFPROTO_IPV4,
1499         .hooknum        = NF_INET_LOCAL_IN,
1500         .priority       = INT_MAX,
1501 };
1502
1503 static int stt_start(struct net *net)
1504 {
1505         struct stt_net *sn = net_generic(net, stt_net_id);
1506         int err;
1507         int i;
1508
1509         if (n_tunnels) {
1510                 n_tunnels++;
1511                 return 0;
1512         }
1513         get_random_bytes(&frag_hash_seed, sizeof(u32));
1514
1515         stt_percpu_data = alloc_percpu(struct stt_percpu);
1516         if (!stt_percpu_data) {
1517                 err = -ENOMEM;
1518                 goto error;
1519         }
1520
1521         for_each_possible_cpu(i) {
1522                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1523                 struct flex_array *frag_hash;
1524
1525                 spin_lock_init(&stt_percpu->lock);
1526                 INIT_LIST_HEAD(&stt_percpu->frag_lru);
1527                 get_random_bytes(&per_cpu(pkt_seq_counter, i), sizeof(u32));
1528
1529                 frag_hash = flex_array_alloc(sizeof(struct pkt_frag),
1530                                              FRAG_HASH_ENTRIES,
1531                                              GFP_KERNEL | __GFP_ZERO);
1532                 if (!frag_hash) {
1533                         err = -ENOMEM;
1534                         goto free_percpu;
1535                 }
1536                 stt_percpu->frag_hash = frag_hash;
1537
1538                 err = flex_array_prealloc(stt_percpu->frag_hash, 0,
1539                                           FRAG_HASH_ENTRIES,
1540                                           GFP_KERNEL | __GFP_ZERO);
1541                 if (err)
1542                         goto free_percpu;
1543         }
1544         schedule_clean_percpu();
1545         n_tunnels++;
1546
1547         if (sn->n_tunnels) {
1548                 sn->n_tunnels++;
1549                 return 0;
1550         }
1551 #ifdef HAVE_NF_REGISTER_NET_HOOK
1552         /* On kernels which support per-net nf-hooks, nf_register_hook() takes
1553          * the rtnl lock, which results in a deadlock during STT device
1554          * creation.  Therefore use this new API.
1555          */
1556         err = nf_register_net_hook(net, &nf_hook_ops);
1557 #else
1558         err = nf_register_hook(&nf_hook_ops);
1559 #endif
1560         if (err)
1561                 goto dec_n_tunnel;
1562         sn->n_tunnels++;
1563         return 0;
1564
1565 dec_n_tunnel:
1566         n_tunnels--;
1567 free_percpu:
1568         for_each_possible_cpu(i) {
1569                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1570
1571                 if (stt_percpu->frag_hash)
1572                         flex_array_free(stt_percpu->frag_hash);
1573         }
1574
1575         free_percpu(stt_percpu_data);
1576
1577 error:
1578         return err;
1579 }
1580
1581 static void stt_cleanup(struct net *net)
1582 {
1583         struct stt_net *sn = net_generic(net, stt_net_id);
1584         int i;
1585
1586         sn->n_tunnels--;
1587         if (sn->n_tunnels)
1588                 goto out;
1589 #ifdef HAVE_NF_REGISTER_NET_HOOK
1590         nf_unregister_net_hook(net, &nf_hook_ops);
1591 #else
1592         nf_unregister_hook(&nf_hook_ops);
1593 #endif
1594
1595 out:
1596         n_tunnels--;
1597         if (n_tunnels)
1598                 return;
1599
1600         cancel_delayed_work_sync(&clean_percpu_wq);
1601         for_each_possible_cpu(i) {
1602                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1603                 int j;
1604
1605                 for (j = 0; j < FRAG_HASH_ENTRIES; j++) {
1606                         struct pkt_frag *frag;
1607
1608                         frag = flex_array_get(stt_percpu->frag_hash, j);
1609                         kfree_skb_list(frag->skbs);
1610                 }
1611
1612                 flex_array_free(stt_percpu->frag_hash);
1613         }
1614
1615         free_percpu(stt_percpu_data);
1616 }
1617
1618 static netdev_tx_t stt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1619 {
1620 #ifdef HAVE_METADATA_DST
1621         return ovs_stt_xmit(skb);
1622 #else
1623         /* Drop all packets coming from the networking stack.  The OVS
1624          * CB is not initialized for these packets.
1625          */
1626         dev_kfree_skb(skb);
1627         dev->stats.tx_dropped++;
1628         return NETDEV_TX_OK;
1629 #endif
1630 }
1631
1632 /* Set up stats when the device is created. */
1633 static int stt_init(struct net_device *dev)
1634 {
1635         dev->tstats = (typeof(dev->tstats)) netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
1636         if (!dev->tstats)
1637                 return -ENOMEM;
1638
1639         return 0;
1640 }
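
/* Illustrative sketch only: the per-CPU counters allocated in stt_init() are
 * the ones ip_tunnel_get_stats64() aggregates.  A receive path updating them
 * (the actual update sites live elsewhere in this file) would be expected to
 * look roughly like this:
 *
 *	struct pcpu_sw_netstats *stats = this_cpu_ptr(dev->tstats);
 *
 *	u64_stats_update_begin(&stats->syncp);
 *	stats->rx_packets++;
 *	stats->rx_bytes += skb->len;
 *	u64_stats_update_end(&stats->syncp);
 */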
1641
1642 static void stt_uninit(struct net_device *dev)
1643 {
1644         free_percpu(dev->tstats);
1645 }
1646
1647 static int stt_open(struct net_device *dev)
1648 {
1649         struct stt_dev *stt = netdev_priv(dev);
1650         struct net *net = stt->net;
1651         struct stt_net *sn = net_generic(net, stt_net_id);
1652         int err;
1653
1654         err = stt_start(net);
1655         if (err)
1656                 return err;
1657
1658         err = tcp_sock_create4(net, stt->dst_port, &stt->sock);
1659         if (err)
1660                 return err;
1661         list_add_rcu(&stt->up_next, &sn->stt_up_list);
1662         return 0;
1663 }
1664
1665 static int stt_stop(struct net_device *dev)
1666 {
1667         struct stt_dev *stt_dev = netdev_priv(dev);
1668         struct net *net = stt_dev->net;
1669
1670         list_del_rcu(&stt_dev->up_next);
1671         tcp_sock_release(stt_dev->sock);
1672         stt_dev->sock = NULL;
1673         stt_cleanup(net);
1674         return 0;
1675 }
1676
1677 static const struct net_device_ops stt_netdev_ops = {
1678         .ndo_init               = stt_init,
1679         .ndo_uninit             = stt_uninit,
1680         .ndo_open               = stt_open,
1681         .ndo_stop               = stt_stop,
1682         .ndo_start_xmit         = stt_dev_xmit,
1683         .ndo_get_stats64        = ip_tunnel_get_stats64,
1684         .ndo_change_mtu         = eth_change_mtu,
1685         .ndo_validate_addr      = eth_validate_addr,
1686         .ndo_set_mac_address    = eth_mac_addr,
1687 };
1688
1689 static void stt_get_drvinfo(struct net_device *dev,
1690                 struct ethtool_drvinfo *drvinfo)
1691 {
1692         strlcpy(drvinfo->version, STT_NETDEV_VER, sizeof(drvinfo->version));
1693         strlcpy(drvinfo->driver, "stt", sizeof(drvinfo->driver));
1694 }
1695
1696 static const struct ethtool_ops stt_ethtool_ops = {
1697         .get_drvinfo    = stt_get_drvinfo,
1698         .get_link       = ethtool_op_get_link,
1699 };
1700
1701 /* Information for udev that this is a virtual tunnel endpoint. */
1702 static struct device_type stt_type = {
1703         .name = "stt",
1704 };
1705
1706 /* Initialize the device structure. */
1707 static void stt_setup(struct net_device *dev)
1708 {
1709         ether_setup(dev);
1710
1711         dev->netdev_ops = &stt_netdev_ops;
1712         dev->ethtool_ops = &stt_ethtool_ops;
1713         dev->destructor = free_netdev;
1714
1715         SET_NETDEV_DEVTYPE(dev, &stt_type);
1716
1717         dev->features    |= NETIF_F_LLTX | NETIF_F_NETNS_LOCAL;
1718         dev->features    |= NETIF_F_SG | NETIF_F_HW_CSUM;
1719         dev->features    |= NETIF_F_RXCSUM;
1720         dev->features    |= NETIF_F_GSO_SOFTWARE;
1721
1722         dev->hw_features |= NETIF_F_SG | NETIF_F_HW_CSUM | NETIF_F_RXCSUM;
1723         dev->hw_features |= NETIF_F_GSO_SOFTWARE;
1724
1725 #ifdef HAVE_METADATA_DST
1726         netif_keep_dst(dev);
1727 #endif
1728         dev->priv_flags |= IFF_LIVE_ADDR_CHANGE | IFF_NO_QUEUE;
1729         eth_hw_addr_random(dev);
1730 }
1731
1732 static const struct nla_policy stt_policy[IFLA_STT_MAX + 1] = {
1733         [IFLA_STT_PORT]              = { .type = NLA_U16 },
1734 };
1735
1736 static int stt_validate(struct nlattr *tb[], struct nlattr *data[])
1737 {
1738         if (tb[IFLA_ADDRESS]) {
1739                 if (nla_len(tb[IFLA_ADDRESS]) != ETH_ALEN)
1740                         return -EINVAL;
1741
1742                 if (!is_valid_ether_addr(nla_data(tb[IFLA_ADDRESS])))
1743                         return -EADDRNOTAVAIL;
1744         }
1745
1746         return 0;
1747 }
1748
1749 static struct stt_dev *find_dev(struct net *net, __be16 dst_port)
1750 {
1751         struct stt_net *sn = net_generic(net, stt_net_id);
1752         struct stt_dev *dev;
1753
1754         list_for_each_entry(dev, &sn->stt_list, next) {
1755                 if (dev->dst_port == dst_port)
1756                         return dev;
1757         }
1758         return NULL;
1759 }
1760
1761 static int stt_configure(struct net *net, struct net_device *dev,
1762                           __be16 dst_port)
1763 {
1764         struct stt_net *sn = net_generic(net, stt_net_id);
1765         struct stt_dev *stt = netdev_priv(dev);
1766         int err;
1767
1768         stt->net = net;
1769         stt->dev = dev;
1770
1771         stt->dst_port = dst_port;
1772
1773         if (find_dev(net, dst_port))
1774                 return -EBUSY;
1775
1776         err = register_netdevice(dev);
1777         if (err)
1778                 return err;
1779
1780         list_add(&stt->next, &sn->stt_list);
1781         return 0;
1782 }
1783
1784 static int stt_newlink(struct net *net, struct net_device *dev,
1785                 struct nlattr *tb[], struct nlattr *data[])
1786 {
1787         __be16 dst_port = htons(STT_DST_PORT);
1788
1789         if (data[IFLA_STT_PORT])
1790                 dst_port = nla_get_be16(data[IFLA_STT_PORT]);
1791
1792         return stt_configure(net, dev, dst_port);
1793 }
1794
1795 static void stt_dellink(struct net_device *dev, struct list_head *head)
1796 {
1797         struct stt_dev *stt = netdev_priv(dev);
1798
1799         list_del(&stt->next);
1800         unregister_netdevice_queue(dev, head);
1801 }
1802
1803 static size_t stt_get_size(const struct net_device *dev)
1804 {
1805         return nla_total_size(sizeof(__be16));  /* IFLA_STT_PORT */
1806 }
1807
1808 static int stt_fill_info(struct sk_buff *skb, const struct net_device *dev)
1809 {
1810         struct stt_dev *stt = netdev_priv(dev);
1811
1812         if (nla_put_be16(skb, IFLA_STT_PORT, stt->dst_port))
1813                 goto nla_put_failure;
1814
1815         return 0;
1816
1817 nla_put_failure:
1818         return -EMSGSIZE;
1819 }
1820
1821 static struct rtnl_link_ops stt_link_ops __read_mostly = {
1822         .kind           = "stt",
1823         .maxtype        = IFLA_STT_MAX,
1824         .policy         = stt_policy,
1825         .priv_size      = sizeof(struct stt_dev),
1826         .setup          = stt_setup,
1827         .validate       = stt_validate,
1828         .newlink        = stt_newlink,
1829         .dellink        = stt_dellink,
1830         .get_size       = stt_get_size,
1831         .fill_info      = stt_fill_info,
1832 };
1833
1834 struct net_device *ovs_stt_dev_create_fb(struct net *net, const char *name,
1835                                       u8 name_assign_type, u16 dst_port)
1836 {
1837         struct nlattr *tb[IFLA_MAX + 1];
1838         struct net_device *dev;
1839         int err;
1840
1841         memset(tb, 0, sizeof(tb));
1842         dev = rtnl_create_link(net, (char *) name, name_assign_type,
1843                         &stt_link_ops, tb);
1844         if (IS_ERR(dev))
1845                 return dev;
1846
1847         err = stt_configure(net, dev, htons(dst_port));
1848         if (err) {
1849                 free_netdev(dev);
1850                 return ERR_PTR(err);
1851         }
1852         return dev;
1853 }
1854 EXPORT_SYMBOL_GPL(ovs_stt_dev_create_fb);
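
/* Hypothetical caller sketch, not part of this file: a vport implementation
 * creating the fallback device would be expected to do roughly the following,
 * where "net", "name" and "dst_port" stand in for the caller's own state:
 *
 *	struct net_device *dev;
 *	int err;
 *
 *	rtnl_lock();
 *	dev = ovs_stt_dev_create_fb(net, name, NET_NAME_USER, dst_port);
 *	if (IS_ERR(dev)) {
 *		rtnl_unlock();
 *		return ERR_CAST(dev);
 *	}
 *	err = dev_change_flags(dev, dev->flags | IFF_UP);
 *	if (err < 0) {
 *		rtnl_delete_link(dev);
 *		rtnl_unlock();
 *		return ERR_PTR(err);
 *	}
 *	rtnl_unlock();
 *	return dev;
 */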
1855
1856 static int stt_init_net(struct net *net)
1857 {
1858         struct stt_net *sn = net_generic(net, stt_net_id);
1859
1860         INIT_LIST_HEAD(&sn->stt_list);
1861         INIT_LIST_HEAD(&sn->stt_up_list);
1862         return 0;
1863 }
1864
1865 static void stt_exit_net(struct net *net)
1866 {
1867         struct stt_net *sn = net_generic(net, stt_net_id);
1868         struct stt_dev *stt, *next;
1869         struct net_device *dev, *aux;
1870         LIST_HEAD(list);
1871
1872         rtnl_lock();
1873
1874         /* gather any stt devices that were moved into this ns */
1875         for_each_netdev_safe(net, dev, aux)
1876                 if (dev->rtnl_link_ops == &stt_link_ops)
1877                         unregister_netdevice_queue(dev, &list);
1878
1879         list_for_each_entry_safe(stt, next, &sn->stt_list, next) {
1880                 /* If stt->dev is in the same netns, it was already added
1881                  * to the list by the previous loop.
1882                  */
1883                 if (!net_eq(dev_net(stt->dev), net))
1884                         unregister_netdevice_queue(stt->dev, &list);
1885         }
1886
1887         /* unregister the devices gathered above */
1888         unregister_netdevice_many(&list);
1889         rtnl_unlock();
1890 }
1891
1892 static struct pernet_operations stt_net_ops = {
1893         .init = stt_init_net,
1894         .exit = stt_exit_net,
1895         .id   = &stt_net_id,
1896         .size = sizeof(struct stt_net),
1897 };
1898
1899 int stt_init_module(void)
1900 {
1901         int rc;
1902
1903         rc = register_pernet_subsys(&stt_net_ops);
1904         if (rc)
1905                 goto out1;
1906
1907         rc = rtnl_link_register(&stt_link_ops);
1908         if (rc)
1909                 goto out2;
1910
1911         pr_info("STT tunneling driver\n");
1912         return 0;
1913 out2:
1914         unregister_pernet_subsys(&stt_net_ops);
1915 out1:
1916         return rc;
1917 }
1918
1919 void stt_cleanup_module(void)
1920 {
1921         rtnl_link_unregister(&stt_link_ops);
1922         unregister_pernet_subsys(&stt_net_ops);
1923 }
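
/* Illustrative only: these entry points are not registered with
 * module_init()/module_exit() here; a containing module would be expected to
 * drive them from its own init/exit paths, e.g. (hypothetical wrapper):
 *
 *	static int __init wrapper_init(void)
 *	{
 *		return stt_init_module();
 *	}
 *	module_init(wrapper_init);
 *
 *	static void __exit wrapper_exit(void)
 *	{
 *		stt_cleanup_module();
 *	}
 *	module_exit(wrapper_exit);
 */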
1924 #endif