cascardo/ovs.git: datapath/linux/compat/stt.c (dcd70ffc9f633c3fc09f026530b9ad6627e01231)
1 /*
2  * Stateless TCP Tunnel (STT) vport.
3  *
4  * Copyright (c) 2015 Nicira, Inc.
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * as published by the Free Software Foundation; either version
9  * 2 of the License, or (at your option) any later version.
10  */
11
12 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
13 #include <asm/unaligned.h>
14
15 #include <linux/delay.h>
16 #include <linux/flex_array.h>
17 #include <linux/if.h>
18 #include <linux/if_vlan.h>
19 #include <linux/ip.h>
20 #include <linux/ipv6.h>
21 #include <linux/jhash.h>
22 #include <linux/list.h>
23 #include <linux/log2.h>
24 #include <linux/module.h>
25 #include <linux/net.h>
26 #include <linux/netfilter.h>
27 #include <linux/percpu.h>
28 #include <linux/skbuff.h>
29 #include <linux/tcp.h>
30 #include <linux/workqueue.h>
31
32 #include <net/dst_metadata.h>
33 #include <net/icmp.h>
34 #include <net/inet_ecn.h>
35 #include <net/ip.h>
36 #include <net/ip_tunnels.h>
37 #include <net/ip6_checksum.h>
38 #include <net/net_namespace.h>
39 #include <net/netns/generic.h>
40 #include <net/sock.h>
41 #include <net/stt.h>
42 #include <net/tcp.h>
43 #include <net/udp.h>
44
45 #include "gso.h"
46 #include "compat.h"
47
48 #define STT_NETDEV_VER  "0.1"
49 #define STT_DST_PORT 7471
50
51 #ifdef OVS_STT
52 #define STT_VER 0
53
54 /* @dev: Underlying net_device; @net: its network namespace.
55  * @next: Entry in the per-net list of STT ports.
56  * @sock: Fake TCP socket for the STT port.  On receive, STT reassembly can
57  * generate multiple packets; the first carries the tunnel outer header, the
58  * rest are inner packet segments with no STT header.
59  * @dst_port: Destination TCP port of the tunnel.
60  */
61 struct stt_dev {
62         struct net_device       *dev;
63         struct net              *net;
64         struct list_head        next;
65         struct socket           *sock;
66         __be16                  dst_port;
67 };
68
69 #define STT_CSUM_VERIFIED       BIT(0)
70 #define STT_CSUM_PARTIAL        BIT(1)
71 #define STT_PROTO_IPV4          BIT(2)
72 #define STT_PROTO_TCP           BIT(3)
73 #define STT_PROTO_TYPES         (STT_PROTO_IPV4 | STT_PROTO_TCP)
74
75 #define SUPPORTED_GSO_TYPES (SKB_GSO_TCPV4 | SKB_GSO_UDP | SKB_GSO_DODGY | \
76                              SKB_GSO_TCPV6)
77
78 /* The length and offset of a fragment are encoded in the sequence number.
79  * STT_SEQ_LEN_SHIFT is the left shift needed to store the length.
80  * STT_SEQ_OFFSET_MASK is the mask to extract the offset.
81  */
82 #define STT_SEQ_LEN_SHIFT 16
83 #define STT_SEQ_OFFSET_MASK (BIT(STT_SEQ_LEN_SHIFT) - 1)
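/* For illustration (not used by the code): __push_stt_header() stores the
 * total encapsulated length in the upper bits of the fake TCP sequence
 * number, and reassemble() recovers both fields, roughly:
 *
 *	seq         = (tot_len << STT_SEQ_LEN_SHIFT) | frag_offset;
 *	tot_len     = seq >> STT_SEQ_LEN_SHIFT;
 *	frag_offset = seq & STT_SEQ_OFFSET_MASK;
 *
 * e.g. a 3000-byte frame with a fragment at offset 1448 would carry
 * seq = (3000 << 16) | 1448 = 0x0bb805a8.
 */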
84
85 /* The maximum amount of memory used to store packets waiting to be reassembled
86  * on a given CPU.  Once this threshold is exceeded we will begin freeing the
87  * least recently used fragments.
88  */
89 #define REASM_HI_THRESH (4 * 1024 * 1024)
90 /* The target for the high memory evictor.  Once we have exceeded
91  * REASM_HI_THRESH, we will continue freeing fragments until we hit
92  * this limit.
93  */
94 #define REASM_LO_THRESH (3 * 1024 * 1024)
95 /* The length of time a given packet has to be reassembled from the time the
96  * first fragment arrives.  Once this limit is exceeded it becomes available
97  * for cleaning.
98  */
99 #define FRAG_EXP_TIME (30 * HZ)
100 /* Number of hash entries.  Each entry has only a single slot to hold a packet
101  * so if there are collisions, we will drop packets.  This is allocated
102  * per-cpu and each entry consists of struct pkt_frag.
103  */
104 #define FRAG_HASH_SHIFT         8
105 #define FRAG_HASH_ENTRIES       BIT(FRAG_HASH_SHIFT)
106 #define FRAG_HASH_SEGS          ((sizeof(u32) * 8) / FRAG_HASH_SHIFT)
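/* Sketch of how lookup_frag() uses these constants: the 32-bit packet-key
 * hash is consumed in FRAG_HASH_SHIFT-bit slices, so each key can probe
 * FRAG_HASH_SEGS (32 / 8 = 4) distinct buckets before a victim entry is
 * evicted:
 *
 *	for (i = 0; i < FRAG_HASH_SEGS; i++) {
 *		bucket = hash & (FRAG_HASH_ENTRIES - 1);
 *		hash >>= FRAG_HASH_SHIFT;
 *	}
 *
 * e.g. hash 0x12345678 probes buckets 0x78, 0x56, 0x34 and 0x12.
 */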
107
108 #define CLEAN_PERCPU_INTERVAL (30 * HZ)
109
110 struct pkt_key {
111         __be32 saddr;
112         __be32 daddr;
113         __be32 pkt_seq;
114         u32 mark;
115 };
116
117 struct pkt_frag {
118         struct sk_buff *skbs;
119         unsigned long timestamp;
120         struct list_head lru_node;
121         struct pkt_key key;
122 };
123
124 struct stt_percpu {
125         struct flex_array *frag_hash;
126         struct list_head frag_lru;
127         unsigned int frag_mem_used;
128
129         /* Protect frags table. */
130         spinlock_t lock;
131 };
132
133 struct first_frag {
134         struct sk_buff *last_skb;
135         unsigned int mem_used;
136         u16 tot_len;
137         u16 rcvd_len;
138         bool set_ecn_ce;
139 };
140
141 struct frag_skb_cb {
142         u16 offset;
143
144         /* Only valid for the first skb in the chain. */
145         struct first_frag first;
146 };
147
148 #define FRAG_CB(skb) ((struct frag_skb_cb *)(skb)->cb)
149
150 /* per-network namespace private data for this module */
151 struct stt_net {
152         struct list_head stt_list;
153         int n_tunnels;
154 };
155
156 static int stt_net_id;
157
158 static struct stt_percpu __percpu *stt_percpu_data __read_mostly;
159 static u32 frag_hash_seed __read_mostly;
160
161 /* Protects sock-hash and refcounts. */
162 static DEFINE_MUTEX(stt_mutex);
163
164 static int n_tunnels;
165 static DEFINE_PER_CPU(u32, pkt_seq_counter);
166
167 static void clean_percpu(struct work_struct *work);
168 static DECLARE_DELAYED_WORK(clean_percpu_wq, clean_percpu);
169
170 static struct stt_dev *stt_find_sock(struct net *net, __be16 port)
171 {
172         struct stt_net *sn = net_generic(net, stt_net_id);
173         struct stt_dev *stt_dev;
174
175         list_for_each_entry_rcu(stt_dev, &sn->stt_list, next) {
176                 if (inet_sk(stt_dev->sock->sk)->inet_sport == port)
177                         return stt_dev;
178         }
179         return NULL;
180 }
181
182 static __be32 ack_seq(void)
183 {
184 #if NR_CPUS <= 65536
185         u32 pkt_seq, ack;
186
187         pkt_seq = this_cpu_read(pkt_seq_counter);
188         ack = pkt_seq << ilog2(NR_CPUS) | smp_processor_id();
189         this_cpu_inc(pkt_seq_counter);
190
191         return (__force __be32)ack;
192 #else
193 #error "Support for greater than 64k CPUs not implemented"
194 #endif
195 }
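/* Illustrative example of the encoding above: with NR_CPUS = 64,
 * ilog2(NR_CPUS) = 6, so a per-CPU counter value of 10 on CPU 5 yields
 * ack = (10 << 6) | 5 = 645.  The low bits identify the CPU and the high
 * bits carry the per-CPU counter, so concurrently generated ack_seq values
 * never collide across CPUs.
 */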
196
197 static int clear_gso(struct sk_buff *skb)
198 {
199         struct skb_shared_info *shinfo = skb_shinfo(skb);
200         int err;
201
202         if (shinfo->gso_type == 0 && shinfo->gso_size == 0 &&
203             shinfo->gso_segs == 0)
204                 return 0;
205
206         err = skb_unclone(skb, GFP_ATOMIC);
207         if (unlikely(err))
208                 return err;
209
210         shinfo = skb_shinfo(skb);
211         shinfo->gso_type = 0;
212         shinfo->gso_size = 0;
213         shinfo->gso_segs = 0;
214         return 0;
215 }
216
217 static struct sk_buff *normalize_frag_list(struct sk_buff *head,
218                                            struct sk_buff **skbp)
219 {
220         struct sk_buff *skb = *skbp;
221         struct sk_buff *last;
222
223         do {
224                 struct sk_buff *frags;
225
226                 if (skb_shared(skb)) {
227                         struct sk_buff *nskb = skb_clone(skb, GFP_ATOMIC);
228
229                         if (unlikely(!nskb))
230                                 return ERR_PTR(-ENOMEM);
231
232                         nskb->next = skb->next;
233                         consume_skb(skb);
234                         skb = nskb;
235                         *skbp = skb;
236                 }
237
238                 if (head) {
239                         head->len -= skb->len;
240                         head->data_len -= skb->len;
241                         head->truesize -= skb->truesize;
242                 }
243
244                 frags = skb_shinfo(skb)->frag_list;
245                 if (frags) {
246                         int err;
247
248                         err = skb_unclone(skb, GFP_ATOMIC);
249                         if (unlikely(err))
250                                 return ERR_PTR(err);
251
252                         last = normalize_frag_list(skb, &frags);
253                         if (IS_ERR(last))
254                                 return last;
255
256                         skb_shinfo(skb)->frag_list = NULL;
257                         last->next = skb->next;
258                         skb->next = frags;
259                 } else {
260                         last = skb;
261                 }
262
263                 skbp = &skb->next;
264         } while ((skb = skb->next));
265
266         return last;
267 }
268
269 /* Takes a linked list of skbs, which may contain frag_lists (whose
270  * members may in turn contain frag_lists, and so on), and converts
271  * them into a single linear linked list.
272  */
273 static int straighten_frag_list(struct sk_buff **skbp)
274 {
275         struct sk_buff *err_skb;
276
277         err_skb = normalize_frag_list(NULL, skbp);
278         if (IS_ERR(err_skb))
279                 return PTR_ERR(err_skb);
280
281         return 0;
282 }
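/* Worked example of the flattening above: suppose head A carries a frag_list
 * of B -> C, and B itself carries a frag_list holding D.  After
 * straighten_frag_list() the skbs form a single ->next chain,
 *
 *	A -> B -> D -> C
 *
 * with each ancestor's len, data_len and truesize reduced by the descendants
 * unlinked from beneath it (normalize_frag_list() subtracts them one by one).
 */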
283
284 static void copy_skb_metadata(struct sk_buff *to, struct sk_buff *from)
285 {
286         to->protocol = from->protocol;
287         to->tstamp = from->tstamp;
288         to->priority = from->priority;
289         to->mark = from->mark;
290         to->vlan_tci = from->vlan_tci;
291 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
292         to->vlan_proto = from->vlan_proto;
293 #endif
294         skb_copy_secmark(to, from);
295 }
296
297 static void update_headers(struct sk_buff *skb, bool head,
298                                unsigned int l4_offset, unsigned int hdr_len,
299                                bool ipv4, u32 tcp_seq)
300 {
301         u16 old_len, new_len;
302         __be32 delta;
303         struct tcphdr *tcph;
304         int gso_size;
305
306         if (ipv4) {
307                 struct iphdr *iph = (struct iphdr *)(skb->data + ETH_HLEN);
308
309                 old_len = ntohs(iph->tot_len);
310                 new_len = skb->len - ETH_HLEN;
311                 iph->tot_len = htons(new_len);
312
313                 ip_send_check(iph);
314         } else {
315                 struct ipv6hdr *ip6h = (struct ipv6hdr *)(skb->data + ETH_HLEN);
316
317                 old_len = ntohs(ip6h->payload_len);
318                 new_len = skb->len - ETH_HLEN - sizeof(struct ipv6hdr);
319                 ip6h->payload_len = htons(new_len);
320         }
321
322         tcph = (struct tcphdr *)(skb->data + l4_offset);
323         if (!head) {
324                 tcph->seq = htonl(tcp_seq);
325                 tcph->cwr = 0;
326         }
327
328         if (skb->next) {
329                 tcph->fin = 0;
330                 tcph->psh = 0;
331         }
332
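        /* Patch the TCP checksum incrementally for the new length value
         * (one's-complement update in the style of RFC 1624) rather than
         * recomputing it over the whole segment.
         */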
333         delta = htonl(~old_len + new_len);
334         tcph->check = ~csum_fold((__force __wsum)((__force u32)tcph->check +
335                                  (__force u32)delta));
336
337         gso_size = skb_shinfo(skb)->gso_size;
338         if (gso_size && skb->len - hdr_len <= gso_size)
339                 BUG_ON(clear_gso(skb));
340 }
341
342 static bool can_segment(struct sk_buff *head, bool ipv4, bool tcp, bool csum_partial)
343 {
344         /* If no offloading is in use then we don't have enough information
345          * to process the headers.
346          */
347         if (!csum_partial)
348                 goto linearize;
349
350         /* Handling UDP packets requires IP fragmentation, which means that
351          * the L4 checksum can no longer be calculated by hardware (since the
352          * fragments are in different packets).  If we have to compute the
353          * checksum it's faster just to linearize, and large UDP packets are
354          * pretty uncommon anyway, so it's not worth dealing with for now.
355          */
356         if (!tcp)
357                 goto linearize;
358
359         if (ipv4) {
360                 struct iphdr *iph = (struct iphdr *)(head->data + ETH_HLEN);
361
362                 /* It's difficult to get the IP IDs exactly right here due to
363                  * varying segment sizes and potentially multiple layers of
364                  * segmentation.  IP ID isn't important when DF is set and DF
365                  * is generally set for TCP packets, so just linearize if it's
366                  * not.
367                  */
368                 if (!(iph->frag_off & htons(IP_DF)))
369                         goto linearize;
370         } else {
371                 struct ipv6hdr *ip6h = (struct ipv6hdr *)(head->data + ETH_HLEN);
372
373                 /* Jumbograms require more processing to update and we'll
374                  * probably never see them, so just linearize.
375                  */
376                 if (ip6h->payload_len == 0)
377                         goto linearize;
378         }
379         return true;
380
381 linearize:
382         return false;
383 }
384
385 static int copy_headers(struct sk_buff *head, struct sk_buff *frag,
386                             int hdr_len)
387 {
388         u16 csum_start;
389
390         if (skb_cloned(frag) || skb_headroom(frag) < hdr_len) {
391                 int extra_head = hdr_len - skb_headroom(frag);
392
393                 extra_head = extra_head > 0 ? extra_head : 0;
394                 if (unlikely(pskb_expand_head(frag, extra_head, 0,
395                                               GFP_ATOMIC)))
396                         return -ENOMEM;
397         }
398
399         memcpy(__skb_push(frag, hdr_len), head->data, hdr_len);
400
401         csum_start = head->csum_start - skb_headroom(head);
402         frag->csum_start = skb_headroom(frag) + csum_start;
403         frag->csum_offset = head->csum_offset;
404         frag->ip_summed = head->ip_summed;
405
406         skb_shinfo(frag)->gso_size = skb_shinfo(head)->gso_size;
407         skb_shinfo(frag)->gso_type = skb_shinfo(head)->gso_type;
408         skb_shinfo(frag)->gso_segs = 0;
409
410         copy_skb_metadata(frag, head);
411         return 0;
412 }
413
414 static int skb_list_segment(struct sk_buff *head, bool ipv4, int l4_offset)
415 {
416         struct sk_buff *skb;
417         struct tcphdr *tcph;
418         int seg_len;
419         int hdr_len;
420         int tcp_len;
421         u32 seq;
422
423         if (unlikely(!pskb_may_pull(head, l4_offset + sizeof(*tcph))))
424                 return -ENOMEM;
425
426         tcph = (struct tcphdr *)(head->data + l4_offset);
427         tcp_len = tcph->doff * 4;
428         hdr_len = l4_offset + tcp_len;
429
430         if (unlikely((tcp_len < sizeof(struct tcphdr)) ||
431                      (head->len < hdr_len)))
432                 return -EINVAL;
433
434         if (unlikely(!pskb_may_pull(head, hdr_len)))
435                 return -ENOMEM;
436
437         tcph = (struct tcphdr *)(head->data + l4_offset);
438         /* Update header of each segment. */
439         seq = ntohl(tcph->seq);
440         seg_len = skb_pagelen(head) - hdr_len;
441
442         skb = skb_shinfo(head)->frag_list;
443         skb_shinfo(head)->frag_list = NULL;
444         head->next = skb;
445         for (; skb; skb = skb->next) {
446                 int err;
447
448                 head->len -= skb->len;
449                 head->data_len -= skb->len;
450                 head->truesize -= skb->truesize;
451
452                 seq += seg_len;
453                 seg_len = skb->len;
454                 err = copy_headers(head, skb, hdr_len);
455                 if (err)
456                         return err;
457                 update_headers(skb, false, l4_offset, hdr_len, ipv4, seq);
458         }
459         update_headers(head, true, l4_offset, hdr_len, ipv4, 0);
460         return 0;
461 }
462
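/* Tries to merge the skbs chained on ->next into the head with
 * skb_try_coalesce(); straighten_frag_list() is run first so nested
 * frag_lists are flattened.  Whatever cannot be merged is left chained and
 * reattached as the head's frag_list, with the head's len, data_len and
 * truesize updated to cover the whole frame.
 */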
463 static int coalesce_skb(struct sk_buff **headp)
464 {
465         struct sk_buff *frag, *head, *prev;
466         int err;
467
468         err = straighten_frag_list(headp);
469         if (unlikely(err))
470                 return err;
471         head = *headp;
472
473         /* Coalesce frag list. */
474         prev = head;
475         for (frag = head->next; frag; frag = frag->next) {
476                 bool headstolen;
477                 int delta;
478
479                 if (unlikely(skb_unclone(prev, GFP_ATOMIC)))
480                         return -ENOMEM;
481
482                 if (!skb_try_coalesce(prev, frag, &headstolen, &delta)) {
483                         prev = frag;
484                         continue;
485                 }
486
487                 prev->next = frag->next;
488                 frag->len = 0;
489                 frag->data_len = 0;
490                 frag->truesize -= delta;
491                 kfree_skb_partial(frag, headstolen);
492                 frag = prev;
493         }
494
495         if (!head->next)
496                 return 0;
497
498         for (frag = head->next; frag; frag = frag->next) {
499                 head->len += frag->len;
500                 head->data_len += frag->len;
501                 head->truesize += frag->truesize;
502         }
503
504         skb_shinfo(head)->frag_list = head->next;
505         head->next = NULL;
506         return 0;
507 }
508
509 static int __try_to_segment(struct sk_buff *skb, bool csum_partial,
510                             bool ipv4, bool tcp, int l4_offset)
511 {
512         if (can_segment(skb, ipv4, tcp, csum_partial))
513                 return skb_list_segment(skb, ipv4, l4_offset);
514         else
515                 return skb_linearize(skb);
516 }
517
518 static int try_to_segment(struct sk_buff *skb)
519 {
520         struct stthdr *stth = stt_hdr(skb);
521         bool csum_partial = !!(stth->flags & STT_CSUM_PARTIAL);
522         bool ipv4 = !!(stth->flags & STT_PROTO_IPV4);
523         bool tcp = !!(stth->flags & STT_PROTO_TCP);
524         int l4_offset = stth->l4_offset;
525
526         return __try_to_segment(skb, csum_partial, ipv4, tcp, l4_offset);
527 }
528
529 static int segment_skb(struct sk_buff **headp, bool csum_partial,
530                        bool ipv4, bool tcp, int l4_offset)
531 {
532         int err;
533
534         err = coalesce_skb(headp);
535         if (err)
536                 return err;
537
538         if (skb_shinfo(*headp)->frag_list)
539                 return __try_to_segment(*headp, csum_partial,
540                                         ipv4, tcp, l4_offset);
541         return 0;
542 }
543
544 static int __push_stt_header(struct sk_buff *skb, __be64 tun_id,
545                              __be16 s_port, __be16 d_port,
546                              __be32 saddr, __be32 dst,
547                              __be16 l3_proto, u8 l4_proto,
548                              int dst_mtu)
549 {
550         int data_len = skb->len + sizeof(struct stthdr) + STT_ETH_PAD;
551         unsigned short encap_mss;
552         struct tcphdr *tcph;
553         struct stthdr *stth;
554
555         skb_push(skb, STT_HEADER_LEN);
556         skb_reset_transport_header(skb);
557         tcph = tcp_hdr(skb);
558         memset(tcph, 0, STT_HEADER_LEN);
559         stth = stt_hdr(skb);
560
561         if (skb->ip_summed == CHECKSUM_PARTIAL) {
562                 stth->flags |= STT_CSUM_PARTIAL;
563
564                 stth->l4_offset = skb->csum_start -
565                                         (skb_headroom(skb) +
566                                         STT_HEADER_LEN);
567
568                 if (l3_proto == htons(ETH_P_IP))
569                         stth->flags |= STT_PROTO_IPV4;
570
571                 if (l4_proto == IPPROTO_TCP)
572                         stth->flags |= STT_PROTO_TCP;
573
574                 stth->mss = htons(skb_shinfo(skb)->gso_size);
575         } else if (skb->ip_summed == CHECKSUM_UNNECESSARY) {
576                 stth->flags |= STT_CSUM_VERIFIED;
577         }
578
579         stth->vlan_tci = htons(skb->vlan_tci);
580         skb->vlan_tci = 0;
581         put_unaligned(tun_id, &stth->key);
582
583         tcph->source    = s_port;
584         tcph->dest      = d_port;
585         tcph->doff      = sizeof(struct tcphdr) / 4;
586         tcph->ack       = 1;
587         tcph->psh       = 1;
588         tcph->window    = htons(USHRT_MAX);
589         tcph->seq       = htonl(data_len << STT_SEQ_LEN_SHIFT);
590         tcph->ack_seq   = ack_seq();
591         tcph->check     = ~tcp_v4_check(skb->len, saddr, dst, 0);
592
593         skb->csum_start = skb_transport_header(skb) - skb->head;
594         skb->csum_offset = offsetof(struct tcphdr, check);
595         skb->ip_summed = CHECKSUM_PARTIAL;
596
597         encap_mss = dst_mtu - sizeof(struct iphdr) - sizeof(struct tcphdr);
598         if (data_len > encap_mss) {
599                 if (unlikely(skb_unclone(skb, GFP_ATOMIC)))
600                         return -EINVAL;
601
602                 skb_shinfo(skb)->gso_type = SKB_GSO_TCPV4;
603                 skb_shinfo(skb)->gso_size = encap_mss;
604                 skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(data_len, encap_mss);
605         } else {
606                 if (unlikely(clear_gso(skb)))
607                         return -EINVAL;
608         }
609         return 0;
610 }
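/* Rough example of the TSO setup above: with dst_mtu = 1500,
 * encap_mss = 1500 - 20 - 20 = 1460 bytes, so a 4000-byte STT frame
 * (data_len) is handed to the stack as a fake TCP stream of
 * DIV_ROUND_UP(4000, 1460) = 3 segments.
 */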
611
612 static struct sk_buff *push_stt_header(struct sk_buff *head, __be64 tun_id,
613                                        __be16 s_port, __be16 d_port,
614                                        __be32 saddr, __be32 dst,
615                                        __be16 l3_proto, u8 l4_proto,
616                                        int dst_mtu)
617 {
618         struct sk_buff *skb;
619
620         if (skb_shinfo(head)->frag_list) {
621                 bool ipv4 = (l3_proto == htons(ETH_P_IP));
622                 bool tcp = (l4_proto == IPPROTO_TCP);
623                 bool csum_partial = (head->ip_summed == CHECKSUM_PARTIAL);
624                 int l4_offset = skb_transport_offset(head);
625
626                 /* Call skb_orphan() so that the correct truesize is
627                  * reported.  Calling skb_orphan() at this layer is odd, but
628                  * an SKB with a frag list should not be associated with any
629                  * socket, so skb_orphan() should be a no-op. */
630                 skb_orphan(head);
631                 if (unlikely(segment_skb(&head, csum_partial,
632                                          ipv4, tcp, l4_offset)))
633                         goto error;
634         }
635
636         for (skb = head; skb; skb = skb->next) {
637                 if (__push_stt_header(skb, tun_id, s_port, d_port, saddr, dst,
638                                       l3_proto, l4_proto, dst_mtu))
639                         goto error;
640         }
641
642         return head;
643 error:
644         kfree_skb_list(head);
645         return NULL;
646 }
647
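/* Returns 1 if the skb's checksum/GSO state can be encoded directly in the
 * STT header, 0 if the caller must first resolve offloads in software (see
 * handle_offloads()), or a negative errno on failure.
 */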
648 static int stt_can_offload(struct sk_buff *skb, __be16 l3_proto, u8 l4_proto)
649 {
650         if (skb_is_gso(skb) && skb->ip_summed != CHECKSUM_PARTIAL) {
651                 int csum_offset;
652                 __sum16 *csum;
653                 int len;
654
655                 if (l4_proto == IPPROTO_TCP)
656                         csum_offset = offsetof(struct tcphdr, check);
657                 else if (l4_proto == IPPROTO_UDP)
658                         csum_offset = offsetof(struct udphdr, check);
659                 else
660                         return 0;
661
662                 len = skb->len - skb_transport_offset(skb);
663                 csum = (__sum16 *)(skb_transport_header(skb) + csum_offset);
664
665                 if (unlikely(!pskb_may_pull(skb, skb_transport_offset(skb) +
666                                                  csum_offset + sizeof(*csum))))
667                         return -EINVAL;
668
669                 if (l3_proto == htons(ETH_P_IP)) {
670                         struct iphdr *iph = ip_hdr(skb);
671
672                         *csum = ~csum_tcpudp_magic(iph->saddr, iph->daddr,
673                                                    len, l4_proto, 0);
674                 } else if (l3_proto == htons(ETH_P_IPV6)) {
675                         struct ipv6hdr *ip6h = ipv6_hdr(skb);
676
677                         *csum = ~csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
678                                                  len, l4_proto, 0);
679                 } else {
680                         return 0;
681                 }
682                 skb->csum_start = skb_transport_header(skb) - skb->head;
683                 skb->csum_offset = csum_offset;
684                 skb->ip_summed = CHECKSUM_PARTIAL;
685         }
686
687         if (skb->ip_summed == CHECKSUM_PARTIAL) {
688                 /* Assume receiver can only offload TCP/UDP over IPv4/6,
689                  * and require 802.1Q VLANs to be accelerated.
690                  */
691                 if (l3_proto != htons(ETH_P_IP) &&
692                     l3_proto != htons(ETH_P_IPV6))
693                         return 0;
694
695                 if (l4_proto != IPPROTO_TCP && l4_proto != IPPROTO_UDP)
696                         return 0;
697
698                 /* L4 offset must fit in a 1-byte field. */
699                 if (skb->csum_start - skb_headroom(skb) > 255)
700                         return 0;
701
702                 if (skb_shinfo(skb)->gso_type & ~SUPPORTED_GSO_TYPES)
703                         return 0;
704         }
705         /* Total size of encapsulated packet must fit in 16 bits. */
706         if (skb->len + STT_HEADER_LEN + sizeof(struct iphdr) > 65535)
707                 return 0;
708
709 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
710         if (skb_vlan_tag_present(skb) && skb->vlan_proto != htons(ETH_P_8021Q))
711                 return 0;
712 #endif
713         return 1;
714 }
715
716 static bool need_linearize(const struct sk_buff *skb)
717 {
718         struct skb_shared_info *shinfo = skb_shinfo(skb);
719         int i;
720
721         if (unlikely(shinfo->frag_list))
722                 return true;
723
724         /* Generally speaking we should linearize if there are paged frags.
725          * However, if all of the refcounts are 1 we know nobody else can
726          * change them from underneath us and we can skip the linearization.
727          */
728         for (i = 0; i < shinfo->nr_frags; i++)
729                 if (unlikely(page_count(skb_frag_page(&shinfo->frags[i])) > 1))
730                         return true;
731
732         return false;
733 }
734
735 static struct sk_buff *handle_offloads(struct sk_buff *skb, int min_headroom)
736 {
737         int err;
738
739 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
740         if (skb_vlan_tag_present(skb) && skb->vlan_proto != htons(ETH_P_8021Q)) {
741
742                 min_headroom += VLAN_HLEN;
743                 if (skb_headroom(skb) < min_headroom) {
744                         int head_delta = SKB_DATA_ALIGN(min_headroom -
745                                                         skb_headroom(skb) + 16);
746
747                         err = pskb_expand_head(skb, max_t(int, head_delta, 0),
748                                                0, GFP_ATOMIC);
749                         if (unlikely(err))
750                                 goto error;
751                 }
752
753                 skb = __vlan_hwaccel_push_inside(skb);
754                 if (!skb) {
755                         err = -ENOMEM;
756                         goto error;
757                 }
758         }
759 #endif
760
761         if (skb_is_gso(skb)) {
762                 struct sk_buff *nskb;
763                 char cb[sizeof(skb->cb)];
764
765                 memcpy(cb, skb->cb, sizeof(cb));
766
767                 nskb = __skb_gso_segment(skb, 0, false);
768                 if (IS_ERR(nskb)) {
769                         err = PTR_ERR(nskb);
770                         goto error;
771                 }
772
773                 consume_skb(skb);
774                 skb = nskb;
775                 while (nskb) {
776                         memcpy(nskb->cb, cb, sizeof(cb));
777                         nskb = nskb->next;
778                 }
779         } else if (skb->ip_summed == CHECKSUM_PARTIAL) {
780                 /* Pages aren't locked and could change at any time.
781                  * If this happens after we compute the checksum, the
782                  * checksum will be wrong.  We linearize now to avoid
783                  * this problem.
784                  */
785                 if (unlikely(need_linearize(skb))) {
786                         err = __skb_linearize(skb);
787                         if (unlikely(err))
788                                 goto error;
789                 }
790
791                 err = skb_checksum_help(skb);
792                 if (unlikely(err))
793                         goto error;
794         }
795         skb->ip_summed = CHECKSUM_NONE;
796
797         return skb;
798 error:
799         kfree_skb(skb);
800         return ERR_PTR(err);
801 }
802
803 static int skb_list_xmit(struct rtable *rt, struct sk_buff *skb, __be32 src,
804                          __be32 dst, __u8 tos, __u8 ttl, __be16 df)
805 {
806         int len = 0;
807
808         while (skb) {
809                 struct sk_buff *next = skb->next;
810
811                 if (next)
812                         dst_clone(&rt->dst);
813
814                 skb->next = NULL;
815                 len += iptunnel_xmit(NULL, rt, skb, src, dst, IPPROTO_TCP,
816                                      tos, ttl, df, false);
817
818                 skb = next;
819         }
820         return len;
821 }
822
823 static u8 parse_ipv6_l4_proto(struct sk_buff *skb)
824 {
825         unsigned int nh_ofs = skb_network_offset(skb);
826         int payload_ofs;
827         struct ipv6hdr *nh;
828         uint8_t nexthdr;
829         __be16 frag_off;
830
831         if (unlikely(!pskb_may_pull(skb, nh_ofs + sizeof(struct ipv6hdr))))
832                 return 0;
833
834         nh = ipv6_hdr(skb);
835         nexthdr = nh->nexthdr;
836         payload_ofs = (u8 *)(nh + 1) - skb->data;
837
838         payload_ofs = ipv6_skip_exthdr(skb, payload_ofs, &nexthdr, &frag_off);
839         if (unlikely(payload_ofs < 0))
840                 return 0;
841
842         return nexthdr;
843 }
844
845 static u8 skb_get_l4_proto(struct sk_buff *skb, __be16 l3_proto)
846 {
847         if (l3_proto == htons(ETH_P_IP)) {
848                 unsigned int nh_ofs = skb_network_offset(skb);
849
850                 if (unlikely(!pskb_may_pull(skb, nh_ofs + sizeof(struct iphdr))))
851                         return 0;
852
853                 return ip_hdr(skb)->protocol;
854         } else if (l3_proto == htons(ETH_P_IPV6)) {
855                 return parse_ipv6_l4_proto(skb);
856         }
857         return 0;
858 }
859
860 static int stt_xmit_skb(struct sk_buff *skb, struct rtable *rt,
861                  __be32 src, __be32 dst, __u8 tos,
862                  __u8 ttl, __be16 df, __be16 src_port, __be16 dst_port,
863                  __be64 tun_id)
864 {
865         struct ethhdr *eh = eth_hdr(skb);
866         int ret = 0, min_headroom;
867         __be16 inner_l3_proto;
868         u8 inner_l4_proto;
869
870         inner_l3_proto = eh->h_proto;
871         inner_l4_proto = skb_get_l4_proto(skb, inner_l3_proto);
872
873         min_headroom = LL_RESERVED_SPACE(rt->dst.dev) + rt->dst.header_len
874                         + STT_HEADER_LEN + sizeof(struct iphdr);
875
876         if (skb_headroom(skb) < min_headroom || skb_header_cloned(skb)) {
877                 int head_delta = SKB_DATA_ALIGN(min_headroom -
878                                                 skb_headroom(skb) +
879                                                 16);
880
881                 ret = pskb_expand_head(skb, max_t(int, head_delta, 0),
882                                        0, GFP_ATOMIC);
883                 if (unlikely(ret))
884                         goto err_free_rt;
885         }
886
887         ret = stt_can_offload(skb, inner_l3_proto, inner_l4_proto);
888         if (ret < 0)
889                 goto err_free_rt;
890         if (!ret) {
891                 skb = handle_offloads(skb, min_headroom);
892                 if (IS_ERR(skb)) {
893                         ret = PTR_ERR(skb);
894                         skb = NULL;
895                         goto err_free_rt;
896                 }
897         }
898
899         ret = 0;
900         while (skb) {
901                 struct sk_buff *next_skb = skb->next;
902
903                 skb->next = NULL;
904
905                 if (next_skb)
906                         dst_clone(&rt->dst);
907
908                 /* Push STT and TCP header. */
909                 skb = push_stt_header(skb, tun_id, src_port, dst_port, src,
910                                       dst, inner_l3_proto, inner_l4_proto,
911                                       dst_mtu(&rt->dst));
912                 if (unlikely(!skb)) {
913                         ip_rt_put(rt);
914                         goto next;
915                 }
916
917                 /* Push IP header. */
918                 ret += skb_list_xmit(rt, skb, src, dst, tos, ttl, df);
919
920 next:
921                 skb = next_skb;
922         }
923
924         return ret;
925
926 err_free_rt:
927         ip_rt_put(rt);
928         kfree_skb(skb);
929         return ret;
930 }
931
932 netdev_tx_t ovs_stt_xmit(struct sk_buff *skb)
933 {
934         struct net_device *dev = skb->dev;
935         struct stt_dev *stt_dev = netdev_priv(dev);
936         struct net *net = stt_dev->net;
937         __be16 dport = inet_sk(stt_dev->sock->sk)->inet_sport;
938         struct ip_tunnel_key *tun_key;
939         struct ip_tunnel_info *tun_info;
940         struct rtable *rt;
941         struct flowi4 fl;
942         __be16 sport;
943         __be16 df;
944         int err;
945
946         tun_info = skb_tunnel_info(skb);
947         if (unlikely(!tun_info)) {
948                 err = -EINVAL;
949                 goto error;
950         }
951
952         tun_key = &tun_info->key;
953
954         /* Route lookup */
955         memset(&fl, 0, sizeof(fl));
956         fl.daddr = tun_key->u.ipv4.dst;
957         fl.saddr = tun_key->u.ipv4.src;
958         fl.flowi4_tos = RT_TOS(tun_key->tos);
959         fl.flowi4_mark = skb->mark;
960         fl.flowi4_proto = IPPROTO_TCP;
961         rt = ip_route_output_key(net, &fl);
962         if (IS_ERR(rt)) {
963                 err = PTR_ERR(rt);
964                 goto error;
965         }
966
967         df = tun_key->tun_flags & TUNNEL_DONT_FRAGMENT ? htons(IP_DF) : 0;
968         sport = udp_flow_src_port(net, skb, 1, USHRT_MAX, true);
969         skb->ignore_df = 1;
970
971         err = stt_xmit_skb(skb, rt, fl.saddr, tun_key->u.ipv4.dst,
972                             tun_key->tos, tun_key->ttl,
973                             df, sport, dport, tun_key->tun_id);
974         iptunnel_xmit_stats(err, &dev->stats, (struct pcpu_sw_netstats __percpu *)dev->tstats);
975         return NETDEV_TX_OK;
976 error:
977         kfree_skb(skb);
978         dev->stats.tx_errors++;
979         return NETDEV_TX_OK;
980 }
981 EXPORT_SYMBOL(ovs_stt_xmit);
982
983 static void free_frag(struct stt_percpu *stt_percpu,
984                       struct pkt_frag *frag)
985 {
986         stt_percpu->frag_mem_used -= FRAG_CB(frag->skbs)->first.mem_used;
987         kfree_skb_list(frag->skbs);
988         list_del(&frag->lru_node);
989         frag->skbs = NULL;
990 }
991
992 static void evict_frags(struct stt_percpu *stt_percpu)
993 {
994         while (!list_empty(&stt_percpu->frag_lru) &&
995                stt_percpu->frag_mem_used > REASM_LO_THRESH) {
996                 struct pkt_frag *frag;
997
998                 frag = list_first_entry(&stt_percpu->frag_lru,
999                                         struct pkt_frag,
1000                                         lru_node);
1001                 free_frag(stt_percpu, frag);
1002         }
1003 }
1004
1005 static bool pkt_key_match(struct net *net,
1006                           const struct pkt_frag *a, const struct pkt_key *b)
1007 {
1008         return a->key.saddr == b->saddr && a->key.daddr == b->daddr &&
1009                a->key.pkt_seq == b->pkt_seq && a->key.mark == b->mark &&
1010                net_eq(dev_net(a->skbs->dev), net);
1011 }
1012
1013 static u32 pkt_key_hash(const struct net *net, const struct pkt_key *key)
1014 {
1015         u32 initval = frag_hash_seed ^ (u32)(unsigned long)net ^ key->mark;
1016
1017         return jhash_3words((__force u32)key->saddr, (__force u32)key->daddr,
1018                             (__force u32)key->pkt_seq, initval);
1019 }
1020
1021 static struct pkt_frag *lookup_frag(struct net *net,
1022                                     struct stt_percpu *stt_percpu,
1023                                     const struct pkt_key *key, u32 hash)
1024 {
1025         struct pkt_frag *frag, *victim_frag = NULL;
1026         int i;
1027
1028         for (i = 0; i < FRAG_HASH_SEGS; i++) {
1029                 frag = flex_array_get(stt_percpu->frag_hash,
1030                                       hash & (FRAG_HASH_ENTRIES - 1));
1031
1032                 if (frag->skbs &&
1033                     time_before(jiffies, frag->timestamp + FRAG_EXP_TIME) &&
1034                     pkt_key_match(net, frag, key))
1035                         return frag;
1036
1037                 if (!victim_frag ||
1038                     (victim_frag->skbs &&
1039                      (!frag->skbs ||
1040                       time_before(frag->timestamp, victim_frag->timestamp))))
1041                         victim_frag = frag;
1042
1043                 hash >>= FRAG_HASH_SHIFT;
1044         }
1045
1046         if (victim_frag->skbs)
1047                 free_frag(stt_percpu, victim_frag);
1048
1049         return victim_frag;
1050 }
1051
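/* Collects the pieces of one STT frame.  Returns the skb unchanged if it
 * already carries the whole frame, returns the head of a ->next chain of
 * fragments once the final piece arrives, or returns NULL if the fragment
 * was queued (or dropped) and there is nothing to deliver yet.
 */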
1052 static struct sk_buff *reassemble(struct sk_buff *skb)
1053 {
1054         struct iphdr *iph = ip_hdr(skb);
1055         struct tcphdr *tcph = tcp_hdr(skb);
1056         u32 seq = ntohl(tcph->seq);
1057         struct stt_percpu *stt_percpu;
1058         struct sk_buff *last_skb;
1059         struct pkt_frag *frag;
1060         struct pkt_key key;
1061         int tot_len;
1062         u32 hash;
1063
1064         tot_len = seq >> STT_SEQ_LEN_SHIFT;
1065         FRAG_CB(skb)->offset = seq & STT_SEQ_OFFSET_MASK;
1066
1067         if (unlikely(skb->len == 0))
1068                 goto out_free;
1069
1070         if (unlikely(FRAG_CB(skb)->offset + skb->len > tot_len))
1071                 goto out_free;
1072
1073         if (tot_len == skb->len)
1074                 goto out;
1075
1076         key.saddr = iph->saddr;
1077         key.daddr = iph->daddr;
1078         key.pkt_seq = tcph->ack_seq;
1079         key.mark = skb->mark;
1080         hash = pkt_key_hash(dev_net(skb->dev), &key);
1081
1082         stt_percpu = per_cpu_ptr(stt_percpu_data, smp_processor_id());
1083
1084         spin_lock(&stt_percpu->lock);
1085
1086         if (unlikely(stt_percpu->frag_mem_used + skb->truesize > REASM_HI_THRESH))
1087                 evict_frags(stt_percpu);
1088
1089         frag = lookup_frag(dev_net(skb->dev), stt_percpu, &key, hash);
1090         if (!frag->skbs) {
1091                 frag->skbs = skb;
1092                 frag->key = key;
1093                 frag->timestamp = jiffies;
1094                 FRAG_CB(skb)->first.last_skb = skb;
1095                 FRAG_CB(skb)->first.mem_used = skb->truesize;
1096                 FRAG_CB(skb)->first.tot_len = tot_len;
1097                 FRAG_CB(skb)->first.rcvd_len = skb->len;
1098                 FRAG_CB(skb)->first.set_ecn_ce = false;
1099                 list_add_tail(&frag->lru_node, &stt_percpu->frag_lru);
1100                 stt_percpu->frag_mem_used += skb->truesize;
1101
1102                 skb = NULL;
1103                 goto unlock;
1104         }
1105
1106         /* Optimize for the common case where fragments are received in-order
1107          * and not overlapping.
1108          */
1109         last_skb = FRAG_CB(frag->skbs)->first.last_skb;
1110         if (likely(FRAG_CB(last_skb)->offset + last_skb->len ==
1111                    FRAG_CB(skb)->offset)) {
1112                 last_skb->next = skb;
1113                 FRAG_CB(frag->skbs)->first.last_skb = skb;
1114         } else {
1115                 struct sk_buff *prev = NULL, *next;
1116
1117                 for (next = frag->skbs; next; next = next->next) {
1118                         if (FRAG_CB(next)->offset >= FRAG_CB(skb)->offset)
1119                                 break;
1120                         prev = next;
1121                 }
1122
1123                 /* Overlapping fragments aren't allowed.  We shouldn't start
1124                  * before the end of the previous fragment.
1125                  */
1126                 if (prev &&
1127                     FRAG_CB(prev)->offset + prev->len > FRAG_CB(skb)->offset)
1128                         goto unlock_free;
1129
1130                 /* We also shouldn't end after the beginning of the next
1131                  * fragment.
1132                  */
1133                 if (next &&
1134                     FRAG_CB(skb)->offset + skb->len > FRAG_CB(next)->offset)
1135                         goto unlock_free;
1136
1137                 if (prev) {
1138                         prev->next = skb;
1139                 } else {
1140                         FRAG_CB(skb)->first = FRAG_CB(frag->skbs)->first;
1141                         frag->skbs = skb;
1142                 }
1143
1144                 if (next)
1145                         skb->next = next;
1146                 else
1147                         FRAG_CB(frag->skbs)->first.last_skb = skb;
1148         }
1149
1150         FRAG_CB(frag->skbs)->first.set_ecn_ce |= INET_ECN_is_ce(iph->tos);
1151         FRAG_CB(frag->skbs)->first.rcvd_len += skb->len;
1152         FRAG_CB(frag->skbs)->first.mem_used += skb->truesize;
1153         stt_percpu->frag_mem_used += skb->truesize;
1154
1155         if (FRAG_CB(frag->skbs)->first.tot_len ==
1156             FRAG_CB(frag->skbs)->first.rcvd_len) {
1157                 struct sk_buff *frag_head = frag->skbs;
1158
1159                 frag_head->tstamp = skb->tstamp;
1160                 if (FRAG_CB(frag_head)->first.set_ecn_ce)
1161                         INET_ECN_set_ce(frag_head);
1162
1163                 list_del(&frag->lru_node);
1164                 stt_percpu->frag_mem_used -= FRAG_CB(frag_head)->first.mem_used;
1165                 frag->skbs = NULL;
1166                 skb = frag_head;
1167         } else {
1168                 list_move_tail(&frag->lru_node, &stt_percpu->frag_lru);
1169                 skb = NULL;
1170         }
1171
1172         goto unlock;
1173
1174 unlock_free:
1175         kfree_skb(skb);
1176         skb = NULL;
1177 unlock:
1178         spin_unlock(&stt_percpu->lock);
1179         return skb;
1180 out_free:
1181         kfree_skb(skb);
1182         skb = NULL;
1183 out:
1184         return skb;
1185 }
1186
1187 static bool validate_checksum(struct sk_buff *skb)
1188 {
1189         struct iphdr *iph = ip_hdr(skb);
1190
1191         if (skb_csum_unnecessary(skb))
1192                 return true;
1193
1194         if (skb->ip_summed == CHECKSUM_COMPLETE &&
1195             !tcp_v4_check(skb->len, iph->saddr, iph->daddr, skb->csum))
1196                 return true;
1197
1198         skb->csum = csum_tcpudp_nofold(iph->saddr, iph->daddr, skb->len,
1199                                        IPPROTO_TCP, 0);
1200
1201         return __tcp_checksum_complete(skb) == 0;
1202 }
1203
1204 static bool set_offloads(struct sk_buff *skb)
1205 {
1206         struct stthdr *stth = stt_hdr(skb);
1207         unsigned short gso_type;
1208         int l3_header_size;
1209         int l4_header_size;
1210         u16 csum_offset;
1211         u8 proto_type;
1212
1213         if (stth->vlan_tci)
1214                 __vlan_hwaccel_put_tag(skb, htons(ETH_P_8021Q),
1215                                        ntohs(stth->vlan_tci));
1216
1217         if (!(stth->flags & STT_CSUM_PARTIAL)) {
1218                 if (stth->flags & STT_CSUM_VERIFIED)
1219                         skb->ip_summed = CHECKSUM_UNNECESSARY;
1220                 else
1221                         skb->ip_summed = CHECKSUM_NONE;
1222
1223                 return clear_gso(skb) == 0;
1224         }
1225
1226         proto_type = stth->flags & STT_PROTO_TYPES;
1227
1228         switch (proto_type) {
1229         case (STT_PROTO_IPV4 | STT_PROTO_TCP):
1230                 /* TCP/IPv4 */
1231                 csum_offset = offsetof(struct tcphdr, check);
1232                 gso_type = SKB_GSO_TCPV4;
1233                 l3_header_size = sizeof(struct iphdr);
1234                 l4_header_size = sizeof(struct tcphdr);
1235                 skb->protocol = htons(ETH_P_IP);
1236                 break;
1237         case STT_PROTO_TCP:
1238                 /* TCP/IPv6 */
1239                 csum_offset = offsetof(struct tcphdr, check);
1240                 gso_type = SKB_GSO_TCPV6;
1241                 l3_header_size = sizeof(struct ipv6hdr);
1242                 l4_header_size = sizeof(struct tcphdr);
1243                 skb->protocol = htons(ETH_P_IPV6);
1244                 break;
1245         case STT_PROTO_IPV4:
1246                 /* UDP/IPv4 */
1247                 csum_offset = offsetof(struct udphdr, check);
1248                 gso_type = SKB_GSO_UDP;
1249                 l3_header_size = sizeof(struct iphdr);
1250                 l4_header_size = sizeof(struct udphdr);
1251                 skb->protocol = htons(ETH_P_IP);
1252                 break;
1253         default:
1254                 /* UDP/IPv6 */
1255                 csum_offset = offsetof(struct udphdr, check);
1256                 gso_type = SKB_GSO_UDP;
1257                 l3_header_size = sizeof(struct ipv6hdr);
1258                 l4_header_size = sizeof(struct udphdr);
1259                 skb->protocol = htons(ETH_P_IPV6);
1260         }
1261
1262         if (unlikely(stth->l4_offset < ETH_HLEN + l3_header_size))
1263                 return false;
1264
1265         if (unlikely(!pskb_may_pull(skb, stth->l4_offset + l4_header_size)))
1266                 return false;
1267
1268         stth = stt_hdr(skb);
1269
1270         skb->csum_start = skb_headroom(skb) + stth->l4_offset;
1271         skb->csum_offset = csum_offset;
1272         skb->ip_summed = CHECKSUM_PARTIAL;
1273
1274         if (stth->mss) {
1275                 if (unlikely(skb_unclone(skb, GFP_ATOMIC)))
1276                         return false;
1277
1278                 skb_shinfo(skb)->gso_type = gso_type | SKB_GSO_DODGY;
1279                 skb_shinfo(skb)->gso_size = ntohs(stth->mss);
1280                 skb_shinfo(skb)->gso_segs = 0;
1281         } else {
1282                 if (unlikely(clear_gso(skb)))
1283                         return false;
1284         }
1285
1286         return true;
1287 }
1288
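/* Delivers each skb of a reassembled ->next chain to ovs_ip_tunnel_rcv(),
 * taking an extra hold on tun_dst for every packet after the first so that
 * each skb owns its own reference to the metadata dst.
 */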
1289 static void rcv_list(struct net_device *dev, struct sk_buff *skb,
1290                      struct metadata_dst *tun_dst)
1291 {
1292         struct sk_buff *next;
1293
1294         do {
1295                 next = skb->next;
1296                 skb->next = NULL;
1297                 if (next) {
1298                         ovs_dst_hold((struct dst_entry *)tun_dst);
1299                         ovs_skb_dst_set(next, (struct dst_entry *)tun_dst);
1300                 }
1301                 ovs_ip_tunnel_rcv(dev, skb, tun_dst);
1302         } while ((skb = next));
1303 }
1304
1305 #ifndef HAVE_METADATA_DST
1306 static int __stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
1307 {
1308         struct metadata_dst tun_dst;
1309
1310         ovs_ip_tun_rx_dst(&tun_dst.u.tun_info, skb, TUNNEL_KEY | TUNNEL_CSUM,
1311                           get_unaligned(&stt_hdr(skb)->key), 0);
1312         tun_dst.u.tun_info.key.tp_src = tcp_hdr(skb)->source;
1313         tun_dst.u.tun_info.key.tp_dst = tcp_hdr(skb)->dest;
1314
1315         rcv_list(stt_dev->dev, skb, &tun_dst);
1316         return 0;
1317 }
1318 #else
1319 static int __stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
1320 {
1321         struct metadata_dst *tun_dst;
1322         __be16 flags;
1323         __be64 tun_id;
1324
1325         flags = TUNNEL_KEY | TUNNEL_CSUM;
1326         tun_id = get_unaligned(&stt_hdr(skb)->key);
1327         tun_dst = ip_tun_rx_dst(skb, flags, tun_id, 0);
1328         if (!tun_dst)
1329                 return -ENOMEM;
1330         tun_dst->u.tun_info.key.tp_src = tcp_hdr(skb)->source;
1331         tun_dst->u.tun_info.key.tp_dst = tcp_hdr(skb)->dest;
1332
1333         rcv_list(stt_dev->dev, skb, tun_dst);
1334         return 0;
1335 }
1336 #endif
1337
1338 static void stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
1339 {
1340         int err;
1341
1342         if (unlikely(!validate_checksum(skb)))
1343                 goto drop;
1344
1345         skb = reassemble(skb);
1346         if (!skb)
1347                 return;
1348
1349         if (skb->next && coalesce_skb(&skb))
1350                 goto drop;
1351
1352         err = iptunnel_pull_header(skb,
1353                                    sizeof(struct stthdr) + STT_ETH_PAD,
1354                                    htons(ETH_P_TEB));
1355         if (unlikely(err))
1356                 goto drop;
1357
1358         if (unlikely(stt_hdr(skb)->version != 0))
1359                 goto drop;
1360
1361         if (unlikely(!set_offloads(skb)))
1362                 goto drop;
1363
1364         if (skb_shinfo(skb)->frag_list && try_to_segment(skb))
1365                 goto drop;
1366
1367         err = __stt_rcv(stt_dev, skb);
1368         if (err)
1369                 goto drop;
1370         return;
1371 drop:
1372         /* Consume bad packet */
1373         kfree_skb_list(skb);
1374         stt_dev->dev->stats.rx_errors++;
1375 }
1376
1377 static void tcp_sock_release(struct socket *sock)
1378 {
1379         kernel_sock_shutdown(sock, SHUT_RDWR);
1380         sock_release(sock);
1381 }
1382
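/* Creates and binds, but never listens on, a kernel TCP socket.  The socket
 * only reserves the STT port; incoming STT traffic is stolen before socket
 * delivery by the NF_INET_LOCAL_IN hook (nf_ip_hook()) below.
 */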
1383 static int tcp_sock_create4(struct net *net, __be16 port,
1384                             struct socket **sockp)
1385 {
1386         struct sockaddr_in tcp_addr;
1387         struct socket *sock = NULL;
1388         int err;
1389
1390         err = sock_create_kern(net, AF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
1391         if (err < 0)
1392                 goto error;
1393
1394         memset(&tcp_addr, 0, sizeof(tcp_addr));
1395         tcp_addr.sin_family = AF_INET;
1396         tcp_addr.sin_addr.s_addr = htonl(INADDR_ANY);
1397         tcp_addr.sin_port = port;
1398         err = kernel_bind(sock, (struct sockaddr *)&tcp_addr,
1399                           sizeof(tcp_addr));
1400         if (err < 0)
1401                 goto error;
1402
1403         *sockp = sock;
1404         return 0;
1405
1406 error:
1407         if (sock)
1408                 tcp_sock_release(sock);
1409         *sockp = NULL;
1410         return err;
1411 }
1412
1413 static void schedule_clean_percpu(void)
1414 {
1415         schedule_delayed_work(&clean_percpu_wq, CLEAN_PERCPU_INTERVAL);
1416 }
1417
1418 static void clean_percpu(struct work_struct *work)
1419 {
1420         int i;
1421
1422         for_each_possible_cpu(i) {
1423                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1424                 int j;
1425
1426                 for (j = 0; j < FRAG_HASH_ENTRIES; j++) {
1427                         struct pkt_frag *frag;
1428
1429                         frag = flex_array_get(stt_percpu->frag_hash, j);
1430                         if (!frag->skbs ||
1431                             time_before(jiffies, frag->timestamp + FRAG_EXP_TIME))
1432                                 continue;
1433
1434                         spin_lock_bh(&stt_percpu->lock);
1435
1436                         if (frag->skbs &&
1437                             time_after(jiffies, frag->timestamp + FRAG_EXP_TIME))
1438                                 free_frag(stt_percpu, frag);
1439
1440                         spin_unlock_bh(&stt_percpu->lock);
1441                 }
1442         }
1443         schedule_clean_percpu();
1444 }
1445
1446 #ifdef HAVE_NF_HOOKFN_ARG_OPS
1447 #define FIRST_PARAM const struct nf_hook_ops *ops
1448 #else
1449 #define FIRST_PARAM unsigned int hooknum
1450 #endif
1451
1452 #ifdef HAVE_NF_HOOK_STATE
1453 #if RHEL_RELEASE_CODE > RHEL_RELEASE_VERSION(7,0)
1454 /* RHEL nfhook hacks. */
1455 #ifndef __GENKSYMS__
1456 #define LAST_PARAM const struct net_device *in, const struct net_device *out, \
1457                    const struct nf_hook_state *state
1458 #else
1459 #define LAST_PARAM const struct net_device *in, const struct net_device *out, \
1460                    int (*okfn)(struct sk_buff *)
1461 #endif
1462 #else
1463 #define LAST_PARAM const struct nf_hook_state *state
1464 #endif
1465 #else
1466 #define LAST_PARAM const struct net_device *in, const struct net_device *out, \
1467                    int (*okfn)(struct sk_buff *)
1468 #endif
1469
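/* FIRST_PARAM and LAST_PARAM expand to whatever leading and trailing
 * arguments this kernel's nf_hookfn signature expects, so nf_ip_hook()
 * below keeps a single definition across mainline and RHEL variants.
 */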
1470 static unsigned int nf_ip_hook(FIRST_PARAM, struct sk_buff *skb, LAST_PARAM)
1471 {
1472         struct stt_dev *stt_dev;
1473         int ip_hdr_len;
1474
1475         if (ip_hdr(skb)->protocol != IPPROTO_TCP)
1476                 return NF_ACCEPT;
1477
1478         ip_hdr_len = ip_hdrlen(skb);
1479         if (unlikely(!pskb_may_pull(skb, ip_hdr_len + sizeof(struct tcphdr))))
1480                 return NF_ACCEPT;
1481
1482         skb_set_transport_header(skb, ip_hdr_len);
1483
1484         stt_dev = stt_find_sock(dev_net(skb->dev), tcp_hdr(skb)->dest);
1485         if (!stt_dev)
1486                 return NF_ACCEPT;
1487
1488         __skb_pull(skb, ip_hdr_len + sizeof(struct tcphdr));
1489         stt_rcv(stt_dev, skb);
1490         return NF_STOLEN;
1491 }
1492
1493 static struct nf_hook_ops nf_hook_ops __read_mostly = {
1494         .hook           = nf_ip_hook,
1495         .owner          = THIS_MODULE,
1496         .pf             = NFPROTO_IPV4,
1497         .hooknum        = NF_INET_LOCAL_IN,
1498         .priority       = INT_MAX,
1499 };
1500
1501 static int stt_start(struct net *net)
1502 {
1503         struct stt_net *sn = net_generic(net, stt_net_id);
1504         int err;
1505         int i;
1506
1507         if (n_tunnels) {
1508                 n_tunnels++;
1509                 return 0;
1510         }
1511         get_random_bytes(&frag_hash_seed, sizeof(u32));
1512
1513         stt_percpu_data = alloc_percpu(struct stt_percpu);
1514         if (!stt_percpu_data) {
1515                 err = -ENOMEM;
1516                 goto error;
1517         }
1518
1519         for_each_possible_cpu(i) {
1520                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1521                 struct flex_array *frag_hash;
1522
1523                 spin_lock_init(&stt_percpu->lock);
1524                 INIT_LIST_HEAD(&stt_percpu->frag_lru);
1525                 get_random_bytes(&per_cpu(pkt_seq_counter, i), sizeof(u32));
1526
1527                 frag_hash = flex_array_alloc(sizeof(struct pkt_frag),
1528                                              FRAG_HASH_ENTRIES,
1529                                              GFP_KERNEL | __GFP_ZERO);
1530                 if (!frag_hash) {
1531                         err = -ENOMEM;
1532                         goto free_percpu;
1533                 }
1534                 stt_percpu->frag_hash = frag_hash;
1535
1536                 err = flex_array_prealloc(stt_percpu->frag_hash, 0,
1537                                           FRAG_HASH_ENTRIES,
1538                                           GFP_KERNEL | __GFP_ZERO);
1539                 if (err)
1540                         goto free_percpu;
1541         }
1542         schedule_clean_percpu();
1543         n_tunnels++;
1544
1545         if (sn->n_tunnels) {
1546                 sn->n_tunnels++;
1547                 return 0;
1548         }
1549 #ifdef HAVE_NF_REGISTER_NET_HOOK
1550         /* On kernels that support per-net netfilter hooks, nf_register_hook()
1551          * takes the rtnl lock, which results in a deadlock during STT device
1552          * creation.  Therefore use the newer per-net API.
1553          */
1554         err = nf_register_net_hook(net, &nf_hook_ops);
1555 #else
1556         err = nf_register_hook(&nf_hook_ops);
1557 #endif
1558         if (err)
1559                 goto free_percpu;
1560         sn->n_tunnels++;
1561         return 0;
1562
1563 free_percpu:
1564         for_each_possible_cpu(i) {
1565                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1566
1567                 if (stt_percpu->frag_hash)
1568                         flex_array_free(stt_percpu->frag_hash);
1569         }
1570
1571         free_percpu(stt_percpu_data);
1572
1573 error:
1574         return err;
1575 }
1576
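/* Tear down what stt_start() set up for this namespace.  The netfilter
 * hook goes away when the last device in the namespace is closed; the
 * per-CPU reassembly state and its cleaner go away when the last device
 * in any namespace is closed, at which point queued fragments are freed.
 */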
1577 static void stt_cleanup(struct net *net)
1578 {
1579         struct stt_net *sn = net_generic(net, stt_net_id);
1580         int i;
1581
1582         sn->n_tunnels--;
1583         if (sn->n_tunnels)
1584                 goto out;
1585 #ifdef HAVE_NF_REGISTER_NET_HOOK
1586         nf_unregister_net_hook(net, &nf_hook_ops);
1587 #else
1588         nf_unregister_hook(&nf_hook_ops);
1589 #endif
1590
1591 out:
1592         n_tunnels--;
1593         if (n_tunnels)
1594                 return;
1595
1596         cancel_delayed_work_sync(&clean_percpu_wq);
1597         for_each_possible_cpu(i) {
1598                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1599                 int j;
1600
1601                 for (j = 0; j < FRAG_HASH_ENTRIES; j++) {
1602                         struct pkt_frag *frag;
1603
1604                         frag = flex_array_get(stt_percpu->frag_hash, j);
1605                         kfree_skb_list(frag->skbs);
1606                 }
1607
1608                 flex_array_free(stt_percpu->frag_hash);
1609         }
1610
1611         free_percpu(stt_percpu_data);
1612 }
1613
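/* Transmit path for the STT netdev.  With metadata dst support the skb
 * already carries its tunnel key and is handed to the STT transmit code;
 * without it the device is only usable through the OVS datapath, so
 * packets injected by the stack are dropped and counted as tx_dropped.
 */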
1614 static netdev_tx_t stt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1615 {
1616 #ifdef HAVE_METADATA_DST
1617         return ovs_stt_xmit(skb);
1618 #else
        /* Drop all packets coming from the networking stack; the OVS
         * control block (OVS_CB) is not initialized for them.
         */
1622         dev_kfree_skb(skb);
1623         dev->stats.tx_dropped++;
1624         return NETDEV_TX_OK;
1625 #endif
1626 }
1627
/* Set up per-CPU stats when the device is created. */
1629 static int stt_init(struct net_device *dev)
1630 {
1631         dev->tstats = (typeof(dev->tstats)) netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
1632         if (!dev->tstats)
1633                 return -ENOMEM;
1634
1635         return 0;
1636 }
1637
1638 static void stt_uninit(struct net_device *dev)
1639 {
1640         free_percpu(dev->tstats);
1641 }
1642
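/* ndo_open: take a reference on the shared receive path, then create the
 * kernel TCP socket bound to this device's destination port.
 */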
1643 static int stt_open(struct net_device *dev)
1644 {
1645         struct stt_dev *stt = netdev_priv(dev);
1646         struct net *net = stt->net;
1647         int err;
1648
1649         err = stt_start(net);
1650         if (err)
1651                 return err;
1652
        err = tcp_sock_create4(net, stt->dst_port, &stt->sock);
        if (err) {
                /* Keep the start/cleanup counters balanced on failure. */
                stt_cleanup(net);
                return err;
        }
        return 0;
1657 }
1658
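/* ndo_stop: release the port's TCP socket and drop our reference on the
 * shared receive path.
 */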
1659 static int stt_stop(struct net_device *dev)
1660 {
1661         struct stt_dev *stt_dev = netdev_priv(dev);
1662         struct net *net = stt_dev->net;
1663
1664         tcp_sock_release(stt_dev->sock);
1665         stt_dev->sock = NULL;
1666         stt_cleanup(net);
1667         return 0;
1668 }
1669
1670 static const struct net_device_ops stt_netdev_ops = {
1671         .ndo_init               = stt_init,
1672         .ndo_uninit             = stt_uninit,
1673         .ndo_open               = stt_open,
1674         .ndo_stop               = stt_stop,
1675         .ndo_start_xmit         = stt_dev_xmit,
1676         .ndo_get_stats64        = ip_tunnel_get_stats64,
1677         .ndo_change_mtu         = eth_change_mtu,
1678         .ndo_validate_addr      = eth_validate_addr,
1679         .ndo_set_mac_address    = eth_mac_addr,
1680 };
1681
1682 static void stt_get_drvinfo(struct net_device *dev,
1683                 struct ethtool_drvinfo *drvinfo)
1684 {
1685         strlcpy(drvinfo->version, STT_NETDEV_VER, sizeof(drvinfo->version));
1686         strlcpy(drvinfo->driver, "stt", sizeof(drvinfo->driver));
1687 }
1688
1689 static const struct ethtool_ops stt_ethtool_ops = {
1690         .get_drvinfo    = stt_get_drvinfo,
1691         .get_link       = ethtool_op_get_link,
1692 };
1693
/* Information for udev: this is a virtual tunnel endpoint. */
1695 static struct device_type stt_type = {
1696         .name = "stt",
1697 };
1698
1699 /* Initialize the device structure. */
1700 static void stt_setup(struct net_device *dev)
1701 {
1702         ether_setup(dev);
1703
1704         dev->netdev_ops = &stt_netdev_ops;
1705         dev->ethtool_ops = &stt_ethtool_ops;
1706         dev->destructor = free_netdev;
1707
1708         SET_NETDEV_DEVTYPE(dev, &stt_type);
1709
1710         dev->features    |= NETIF_F_LLTX | NETIF_F_NETNS_LOCAL;
1711         dev->features    |= NETIF_F_SG | NETIF_F_HW_CSUM;
1712         dev->features    |= NETIF_F_RXCSUM;
1713         dev->features    |= NETIF_F_GSO_SOFTWARE;
1714
1715         dev->hw_features |= NETIF_F_SG | NETIF_F_HW_CSUM | NETIF_F_RXCSUM;
1716         dev->hw_features |= NETIF_F_GSO_SOFTWARE;
1717
1718 #ifdef HAVE_METADATA_DST
1719         netif_keep_dst(dev);
1720 #endif
1721         dev->priv_flags |= IFF_LIVE_ADDR_CHANGE | IFF_NO_QUEUE;
1722         eth_hw_addr_random(dev);
1723 }
1724
1725 static const struct nla_policy stt_policy[IFLA_STT_MAX + 1] = {
1726         [IFLA_STT_PORT]              = { .type = NLA_U16 },
1727 };
1728
1729 static int stt_validate(struct nlattr *tb[], struct nlattr *data[])
1730 {
1731         if (tb[IFLA_ADDRESS]) {
1732                 if (nla_len(tb[IFLA_ADDRESS]) != ETH_ALEN)
1733                         return -EINVAL;
1734
1735                 if (!is_valid_ether_addr(nla_data(tb[IFLA_ADDRESS])))
1736                         return -EADDRNOTAVAIL;
1737         }
1738
1739         return 0;
1740 }
1741
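/* Find the STT device in this namespace that owns @dst_port.  The device
 * list is only changed under rtnl, and callers hold rtnl as well (they go
 * on to call register_netdevice()), so a plain traversal is sufficient.
 */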
1742 static struct stt_dev *find_dev(struct net *net, __be16 dst_port)
1743 {
1744         struct stt_net *sn = net_generic(net, stt_net_id);
1745         struct stt_dev *dev;
1746
1747         list_for_each_entry(dev, &sn->stt_list, next) {
1748                 if (dev->dst_port == dst_port)
1749                         return dev;
1750         }
1751         return NULL;
1752 }
1753
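/* Common bring-up for both the rtnl newlink path and
 * ovs_stt_dev_create_fb(): reject a duplicate destination port in this
 * namespace, register the netdev and add it to the per-net device list.
 */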
1754 static int stt_configure(struct net *net, struct net_device *dev,
1755                           __be16 dst_port)
1756 {
1757         struct stt_net *sn = net_generic(net, stt_net_id);
1758         struct stt_dev *stt = netdev_priv(dev);
1759         int err;
1760
1761         stt->net = net;
1762         stt->dev = dev;
1763
1764         stt->dst_port = dst_port;
1765
1766         if (find_dev(net, dst_port))
1767                 return -EBUSY;
1768
1769         err = register_netdevice(dev);
1770         if (err)
1771                 return err;
1772
1773         list_add_rcu(&stt->next, &sn->stt_list);
1774         return 0;
1775 }
1776
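/* rtnl newlink: use IFLA_STT_PORT if supplied, otherwise fall back to the
 * default STT destination port.
 */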
1777 static int stt_newlink(struct net *net, struct net_device *dev,
1778                 struct nlattr *tb[], struct nlattr *data[])
1779 {
1780         __be16 dst_port = htons(STT_DST_PORT);
1781
1782         if (data[IFLA_STT_PORT])
1783                 dst_port = nla_get_be16(data[IFLA_STT_PORT]);
1784
1785         return stt_configure(net, dev, dst_port);
1786 }
1787
1788 static void stt_dellink(struct net_device *dev, struct list_head *head)
1789 {
1790         struct stt_dev *stt = netdev_priv(dev);
1791
1792         list_del_rcu(&stt->next);
1793         unregister_netdevice_queue(dev, head);
1794 }
1795
1796 static size_t stt_get_size(const struct net_device *dev)
1797 {
        return nla_total_size(sizeof(__be16));  /* IFLA_STT_PORT is a be16 */
1799 }
1800
1801 static int stt_fill_info(struct sk_buff *skb, const struct net_device *dev)
1802 {
1803         struct stt_dev *stt = netdev_priv(dev);
1804
1805         if (nla_put_be16(skb, IFLA_STT_PORT, stt->dst_port))
1806                 goto nla_put_failure;
1807
1808         return 0;
1809
1810 nla_put_failure:
1811         return -EMSGSIZE;
1812 }
1813
1814 static struct rtnl_link_ops stt_link_ops __read_mostly = {
1815         .kind           = "stt",
1816         .maxtype        = IFLA_STT_MAX,
1817         .policy         = stt_policy,
1818         .priv_size      = sizeof(struct stt_dev),
1819         .setup          = stt_setup,
1820         .validate       = stt_validate,
1821         .newlink        = stt_newlink,
1822         .dellink        = stt_dellink,
1823         .get_size       = stt_get_size,
1824         .fill_info      = stt_fill_info,
1825 };
1826
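/* Create an STT device on behalf of the OVS vport layer instead of via
 * rtnl newlink.  The caller is expected to hold rtnl, since this ends in
 * register_netdevice().
 */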
1827 struct net_device *ovs_stt_dev_create_fb(struct net *net, const char *name,
1828                                       u8 name_assign_type, u16 dst_port)
1829 {
1830         struct nlattr *tb[IFLA_MAX + 1];
1831         struct net_device *dev;
1832         int err;
1833
1834         memset(tb, 0, sizeof(tb));
1835         dev = rtnl_create_link(net, (char *) name, name_assign_type,
1836                         &stt_link_ops, tb);
1837         if (IS_ERR(dev))
1838                 return dev;
1839
1840         err = stt_configure(net, dev, htons(dst_port));
1841         if (err) {
1842                 free_netdev(dev);
1843                 return ERR_PTR(err);
1844         }
1845         return dev;
1846 }
1847 EXPORT_SYMBOL_GPL(ovs_stt_dev_create_fb);
1848
1849 static int stt_init_net(struct net *net)
1850 {
1851         struct stt_net *sn = net_generic(net, stt_net_id);
1852
1853         INIT_LIST_HEAD(&sn->stt_list);
1854         return 0;
1855 }
1856
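/* Namespace teardown: unregister, in one batch under rtnl, every STT
 * device that lives in the dying namespace or was created from it but
 * has since moved to another namespace.
 */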
1857 static void stt_exit_net(struct net *net)
1858 {
1859         struct stt_net *sn = net_generic(net, stt_net_id);
1860         struct stt_dev *stt, *next;
1861         struct net_device *dev, *aux;
1862         LIST_HEAD(list);
1863
1864         rtnl_lock();
1865
1866         /* gather any stt devices that were moved into this ns */
1867         for_each_netdev_safe(net, dev, aux)
1868                 if (dev->rtnl_link_ops == &stt_link_ops)
1869                         unregister_netdevice_queue(dev, &list);
1870
1871         list_for_each_entry_safe(stt, next, &sn->stt_list, next) {
                /* If stt->dev is in the same netns, it was already queued
                 * for unregistration by the loop above.
                 */
1875                 if (!net_eq(dev_net(stt->dev), net))
1876                         unregister_netdevice_queue(stt->dev, &list);
1877         }
1878
1879         /* unregister the devices gathered above */
1880         unregister_netdevice_many(&list);
1881         rtnl_unlock();
1882 }
1883
1884 static struct pernet_operations stt_net_ops = {
1885         .init = stt_init_net,
1886         .exit = stt_exit_net,
1887         .id   = &stt_net_id,
1888         .size = sizeof(struct stt_net),
1889 };
1890
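/* Module init: register the per-net state first, then the "stt" link
 * type; module cleanup below undoes this in the reverse order.
 */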
1891 int stt_init_module(void)
1892 {
1893         int rc;
1894
1895         rc = register_pernet_subsys(&stt_net_ops);
1896         if (rc)
1897                 goto out1;
1898
1899         rc = rtnl_link_register(&stt_link_ops);
1900         if (rc)
1901                 goto out2;
1902
1903         pr_info("STT tunneling driver\n");
1904         return 0;
1905 out2:
1906         unregister_pernet_subsys(&stt_net_ops);
1907 out1:
1908         return rc;
1909 }
1910
1911 void stt_cleanup_module(void)
1912 {
1913         rtnl_link_unregister(&stt_link_ops);
1914         unregister_pernet_subsys(&stt_net_ops);
1915 }
1916 #endif