datapath: STT: Fix checksum handling.
[cascardo/ovs.git] / datapath / linux / compat / stt.c
1 /*
2  * Stateless TCP Tunnel (STT) vport.
3  *
4  * Copyright (c) 2015 Nicira, Inc.
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * as published by the Free Software Foundation; either version
9  * 2 of the License, or (at your option) any later version.
10  */
11
12 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
13 #include <asm/unaligned.h>
14
15 #include <linux/delay.h>
16 #include <linux/flex_array.h>
17 #include <linux/if.h>
18 #include <linux/if_vlan.h>
19 #include <linux/ip.h>
20 #include <linux/ipv6.h>
21 #include <linux/jhash.h>
22 #include <linux/list.h>
23 #include <linux/log2.h>
24 #include <linux/module.h>
25 #include <linux/net.h>
26 #include <linux/netfilter.h>
27 #include <linux/percpu.h>
28 #include <linux/skbuff.h>
29 #include <linux/tcp.h>
30 #include <linux/workqueue.h>
31
32 #include <net/dst_metadata.h>
33 #include <net/icmp.h>
34 #include <net/inet_ecn.h>
35 #include <net/ip.h>
36 #include <net/ip_tunnels.h>
37 #include <net/ip6_checksum.h>
38 #include <net/net_namespace.h>
39 #include <net/netns/generic.h>
40 #include <net/sock.h>
41 #include <net/stt.h>
42 #include <net/tcp.h>
43 #include <net/udp.h>
44
45 #include "gso.h"
46 #include "compat.h"
47
/* Driver version string (its users are not visible in this part of the file). */
#define STT_NETDEV_VER  "0.1"
/* STT TCP destination port number (host byte order). */
#define STT_DST_PORT 7471

#ifdef OVS_STT
/* STT protocol version; presumably written to / checked against the STT
 * header version field (usage not visible here).
 */
#define STT_VER 0
53
/* Per-device state of an STT tunnel netdev (obtained with netdev_priv()).
 *
 * @dev: The STT net_device.
 * @net: Network namespace the device belongs to.
 * @next: Linkage in the per-net list of STT ports (presumably
 *        stt_net.stt_list; list insertion is not visible here).
 * @up_next: Linkage in stt_net.stt_up_list, the per-net list of devices
 *           that are up; searched under RCU by stt_find_up_dev().
 * @sock: Fake TCP socket for the STT port.
 * @dst_port: TCP destination port, in network byte order.
 */
struct stt_dev {
	struct net_device       *dev;
	struct net              *net;
	struct list_head        next;
	struct list_head        up_next;
	struct socket           *sock;
	__be16                  dst_port;
};
69
/* Bits of the STT header flags field (stth->flags).  On transmit they are
 * set by __push_stt_header() from the inner packet's checksum state.
 */
/* Inner packet was CHECKSUM_UNNECESSARY at the sender. */
#define STT_CSUM_VERIFIED       BIT(0)
/* Inner packet is CHECKSUM_PARTIAL; the l4_offset and mss fields are set. */
#define STT_CSUM_PARTIAL        BIT(1)
/* Inner L3 protocol is IPv4 (set by __push_stt_header() only with
 * STT_CSUM_PARTIAL).
 */
#define STT_PROTO_IPV4          BIT(2)
/* Inner L4 protocol is TCP (set by __push_stt_header() only with
 * STT_CSUM_PARTIAL).
 */
#define STT_PROTO_TCP           BIT(3)
#define STT_PROTO_TYPES         (STT_PROTO_IPV4 | STT_PROTO_TCP)

/* GSO types that can be handed to the receiver as-is; a packet carrying any
 * other gso_type bit is resolved in software (see stt_can_offload()).
 */
#define SUPPORTED_GSO_TYPES (SKB_GSO_TCPV4 | SKB_GSO_UDP | SKB_GSO_DODGY | \
			     SKB_GSO_TCPV6)
78
/* The length and offset of a fragment are encoded in the outer TCP sequence
 * number: the length in the upper bits and the offset in the lower bits.
 * STT_SEQ_LEN_SHIFT is the left shift needed to store the length.
 * STT_SEQ_OFFSET_MASK is the mask to extract the offset.
 */
#define STT_SEQ_LEN_SHIFT 16
#define STT_SEQ_OFFSET_MASK (BIT(STT_SEQ_LEN_SHIFT) - 1)

/* The maximum amount of memory used to store packets waiting to be reassembled
 * on a given CPU.  Once this threshold is exceeded we will begin freeing the
 * least recently used fragments.
 */
#define REASM_HI_THRESH (4 * 1024 * 1024)
/* The target for the high memory evictor.  Once we have exceeded
 * REASM_HI_THRESH, we will continue freeing fragments until we hit
 * this limit.
 */
#define REASM_LO_THRESH (3 * 1024 * 1024)
/* The length of time a given packet has to be reassembled from the time the
 * first fragment arrives.  Once this limit is exceeded it becomes available
 * for cleaning.
 */
#define FRAG_EXP_TIME (30 * HZ)
/* Number of hash entries.  Each entry has only a single slot to hold a packet
 * so if there are collisions, we will drop packets.  This is allocated
 * per-cpu and each entry consists of struct pkt_frag.
 */
#define FRAG_HASH_SHIFT         8
#define FRAG_HASH_ENTRIES       BIT(FRAG_HASH_SHIFT)
/* Number of FRAG_HASH_SHIFT-bit slices in a 32-bit hash value (32 / 8 = 4).
 * Presumably each slice is used as a separate bucket index when probing;
 * the user is not visible here.
 */
#define FRAG_HASH_SEGS          ((sizeof(u32) * 8) / FRAG_HASH_SHIFT)

/* Period of the per-CPU cleanup work (presumably used to (re)arm
 * clean_percpu_wq; not visible here).
 */
#define CLEAN_PERCPU_INTERVAL (30 * HZ)
110
/* Key identifying one STT packet under reassembly.  It is filled on the
 * receive path, which is not visible here.  Presumably saddr/daddr are the
 * outer IPv4 addresses and pkt_seq is the sender's ack_seq() value.
 */
struct pkt_key {
	__be32 saddr;
	__be32 daddr;
	__be32 pkt_seq;
	u32 mark;
};

/* One slot of the per-CPU reassembly hash table. */
struct pkt_frag {
	/* Fragments received so far; the first skb's cb holds struct
	 * first_frag (see struct frag_skb_cb).
	 */
	struct sk_buff *skbs;
	/* Presumably jiffies of the first fragment, checked against
	 * FRAG_EXP_TIME.
	 */
	unsigned long timestamp;
	/* Presumably linkage in stt_percpu.frag_lru. */
	struct list_head lru_node;
	struct pkt_key key;
};

/* Per-CPU reassembly state. */
struct stt_percpu {
	/* FRAG_HASH_ENTRIES slots of struct pkt_frag. */
	struct flex_array *frag_hash;
	/* Least-recently-used ordering of occupied slots (presumably). */
	struct list_head frag_lru;
	/* Memory held by pending fragments; presumably compared against
	 * REASM_HI_THRESH / REASM_LO_THRESH.
	 */
	unsigned int frag_mem_used;

	/* Protect frags table. */
	spinlock_t lock;
};

/* Reassembly bookkeeping kept in the cb of the first skb of a packet. */
struct first_frag {
	/* Tail of the fragment chain (presumably for O(1) append). */
	struct sk_buff *last_skb;
	/* Memory charged for this packet (presumably towards frag_mem_used). */
	unsigned int mem_used;
	/* Expected total length and bytes received so far (presumably from
	 * the length field encoded in the outer sequence number).
	 */
	u16 tot_len;
	u16 rcvd_len;
	/* Presumably set when any fragment carried ECN CE. */
	bool set_ecn_ce;
};

struct frag_skb_cb {
	/* Presumably this fragment's offset in the reassembled packet. */
	u16 offset;

	/* Only valid for the first skb in the chain. */
	struct first_frag first;
};

/* Access the reassembly control block stored in skb->cb.
 * NOTE(review): struct frag_skb_cb must fit in sizeof(skb->cb); no
 * compile-time check is visible here.
 */
#define FRAG_CB(skb) ((struct frag_skb_cb *)(skb)->cb)
150
/* per-network namespace private data for this module */
struct stt_net {
	/* All STT devices in the namespace (presumably linked via stt_dev.next). */
	struct list_head stt_list;
	struct list_head stt_up_list;   /* Devices which are in IFF_UP state. */
	/* Per-namespace tunnel count (users not visible here). */
	int n_tunnels;
#ifdef HAVE_NF_REGISTER_NET_HOOK
	/* Presumably records whether the per-net netfilter hook is registered. */
	bool nf_hook_reg_done;
#endif
};

/* net_generic() id for struct stt_net. */
static int stt_net_id;

/* Per-CPU reassembly state (see struct stt_percpu). */
static struct stt_percpu __percpu *stt_percpu_data __read_mostly;
/* Seed for hashing pkt_key into frag_hash (usage not visible here). */
static u32 frag_hash_seed __read_mostly;

/* Protects sock-hash and refcounts. */
static DEFINE_MUTEX(stt_mutex);

/* Global tunnel count (presumably across namespaces; users not visible here). */
static int n_tunnels;
/* Per-CPU packet counter consumed by ack_seq(). */
static DEFINE_PER_CPU(u32, pkt_seq_counter);

/* Periodic per-CPU cleanup work; clean_percpu() is defined later in the file. */
static void clean_percpu(struct work_struct *work);
static DECLARE_DELAYED_WORK(clean_percpu_wq, clean_percpu);
174
175 static struct stt_dev *stt_find_up_dev(struct net *net, __be16 port)
176 {
177         struct stt_net *sn = net_generic(net, stt_net_id);
178         struct stt_dev *stt_dev;
179
180         list_for_each_entry_rcu(stt_dev, &sn->stt_up_list, up_next) {
181                 if (stt_dev->dst_port == port)
182                         return stt_dev;
183         }
184         return NULL;
185 }
186
187 static __be32 ack_seq(void)
188 {
189 #if NR_CPUS <= 65536
190         u32 pkt_seq, ack;
191
192         pkt_seq = this_cpu_read(pkt_seq_counter);
193         ack = pkt_seq << ilog2(NR_CPUS) | smp_processor_id();
194         this_cpu_inc(pkt_seq_counter);
195
196         return (__force __be32)ack;
197 #else
198 #error "Support for greater than 64k CPUs not implemented"
199 #endif
200 }
201
202 static int clear_gso(struct sk_buff *skb)
203 {
204         struct skb_shared_info *shinfo = skb_shinfo(skb);
205         int err;
206
207         if (shinfo->gso_type == 0 && shinfo->gso_size == 0 &&
208             shinfo->gso_segs == 0)
209                 return 0;
210
211         err = skb_unclone(skb, GFP_ATOMIC);
212         if (unlikely(err))
213                 return err;
214
215         shinfo = skb_shinfo(skb);
216         shinfo->gso_type = 0;
217         shinfo->gso_size = 0;
218         shinfo->gso_segs = 0;
219         return 0;
220 }
221
222 static struct sk_buff *normalize_frag_list(struct sk_buff *head,
223                                            struct sk_buff **skbp)
224 {
225         struct sk_buff *skb = *skbp;
226         struct sk_buff *last;
227
228         do {
229                 struct sk_buff *frags;
230
231                 if (skb_shared(skb)) {
232                         struct sk_buff *nskb = skb_clone(skb, GFP_ATOMIC);
233
234                         if (unlikely(!nskb))
235                                 return ERR_PTR(-ENOMEM);
236
237                         nskb->next = skb->next;
238                         consume_skb(skb);
239                         skb = nskb;
240                         *skbp = skb;
241                 }
242
243                 if (head) {
244                         head->len -= skb->len;
245                         head->data_len -= skb->len;
246                         head->truesize -= skb->truesize;
247                 }
248
249                 frags = skb_shinfo(skb)->frag_list;
250                 if (frags) {
251                         int err;
252
253                         err = skb_unclone(skb, GFP_ATOMIC);
254                         if (unlikely(err))
255                                 return ERR_PTR(err);
256
257                         last = normalize_frag_list(skb, &frags);
258                         if (IS_ERR(last))
259                                 return last;
260
261                         skb_shinfo(skb)->frag_list = NULL;
262                         last->next = skb->next;
263                         skb->next = frags;
264                 } else {
265                         last = skb;
266                 }
267
268                 skbp = &skb->next;
269         } while ((skb = skb->next));
270
271         return last;
272 }
273
274 /* Takes a linked list of skbs, which potentially contain frag_list
275  * (whose members in turn potentially contain frag_lists, etc.) and
276  * converts them into a single linear linked list.
277  */
278 static int straighten_frag_list(struct sk_buff **skbp)
279 {
280         struct sk_buff *err_skb;
281
282         err_skb = normalize_frag_list(NULL, skbp);
283         if (IS_ERR(err_skb))
284                 return PTR_ERR(err_skb);
285
286         return 0;
287 }
288
289 static void copy_skb_metadata(struct sk_buff *to, struct sk_buff *from)
290 {
291         to->protocol = from->protocol;
292         to->tstamp = from->tstamp;
293         to->priority = from->priority;
294         to->mark = from->mark;
295         to->vlan_tci = from->vlan_tci;
296 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
297         to->vlan_proto = from->vlan_proto;
298 #endif
299         skb_copy_secmark(to, from);
300 }
301
/* Rewrite the L3/L4 headers of one segment produced by skb_list_segment()
 * so that they describe this segment instead of the original packet.
 *
 * @skb:       segment whose headers start at skb->data with an Ethernet
 *             header of ETH_HLEN bytes.
 * @head:      true for the first segment, which keeps its original TCP
 *             sequence number and CWR bit.
 * @l4_offset: offset of the TCP header from skb->data.
 * @hdr_len:   length of all copied headers (through the end of TCP).
 * @ipv4:      true for an IPv4 inner header, false for IPv6.
 * @tcp_seq:   sequence number to store in a non-head segment.
 */
static void update_headers(struct sk_buff *skb, bool head,
			       unsigned int l4_offset, unsigned int hdr_len,
			       bool ipv4, u32 tcp_seq)
{
	u16 old_len, new_len;
	__be32 delta;
	struct tcphdr *tcph;
	int gso_size;

	if (ipv4) {
		struct iphdr *iph = (struct iphdr *)(skb->data + ETH_HLEN);

		/* Set tot_len to this segment's size and refresh the IP
		 * header checksum.
		 */
		old_len = ntohs(iph->tot_len);
		new_len = skb->len - ETH_HLEN;
		iph->tot_len = htons(new_len);

		ip_send_check(iph);
	} else {
		struct ipv6hdr *ip6h = (struct ipv6hdr *)(skb->data + ETH_HLEN);

		old_len = ntohs(ip6h->payload_len);
		new_len = skb->len - ETH_HLEN - sizeof(struct ipv6hdr);
		ip6h->payload_len = htons(new_len);
	}

	tcph = (struct tcphdr *)(skb->data + l4_offset);
	if (!head) {
		tcph->seq = htonl(tcp_seq);
		/* CWR is only kept on the first segment. */
		tcph->cwr = 0;
	}

	/* FIN and PSH are only kept on the last segment. */
	if (skb->next) {
		tcph->fin = 0;
		tcph->psh = 0;
	}

	/* Incrementally adjust tcph->check by the one's-complement length
	 * difference (new_len - old_len).  Segmentation is only attempted
	 * for CHECKSUM_PARTIAL packets (see can_segment()), so tcph->check
	 * is expected to hold the non-inverted pseudo-header sum here.
	 */
	delta = htonl(~old_len + new_len);
	tcph->check = ~csum_fold((__force __wsum)((__force u32)tcph->check +
				 (__force u32)delta));

	/* Once the remaining payload fits in a single gso_size chunk, this
	 * segment no longer needs GSO.
	 * NOTE(review): clear_gso() can fail (skb_unclone() with GFP_ATOMIC),
	 * and BUG_ON() turns that failure into a crash.
	 */
	gso_size = skb_shinfo(skb)->gso_size;
	if (gso_size && skb->len - hdr_len <= gso_size)
		BUG_ON(clear_gso(skb));
}
346
347 static bool can_segment(struct sk_buff *head, bool ipv4, bool tcp, bool csum_partial)
348 {
349         /* If no offloading is in use then we don't have enough information
350          * to process the headers.
351          */
352         if (!csum_partial)
353                 goto linearize;
354
355         /* Handling UDP packets requires IP fragmentation, which means that
356          * the L4 checksum can no longer be calculated by hardware (since the
357          * fragments are in different packets.  If we have to compute the
358          * checksum it's faster just to linearize and large UDP packets are
359          * pretty uncommon anyways, so it's not worth dealing with for now.
360          */
361         if (!tcp)
362                 goto linearize;
363
364         if (ipv4) {
365                 struct iphdr *iph = (struct iphdr *)(head->data + ETH_HLEN);
366
367                 /* It's difficult to get the IP IDs exactly right here due to
368                  * varying segment sizes and potentially multiple layers of
369                  * segmentation.  IP ID isn't important when DF is set and DF
370                  * is generally set for TCP packets, so just linearize if it's
371                  * not.
372                  */
373                 if (!(iph->frag_off & htons(IP_DF)))
374                         goto linearize;
375         } else {
376                 struct ipv6hdr *ip6h = (struct ipv6hdr *)(head->data + ETH_HLEN);
377
378                 /* Jumbograms require more processing to update and we'll
379                  * probably never see them, so just linearize.
380                  */
381                 if (ip6h->payload_len == 0)
382                         goto linearize;
383         }
384         return true;
385
386 linearize:
387         return false;
388 }
389
390 static int copy_headers(struct sk_buff *head, struct sk_buff *frag,
391                             int hdr_len)
392 {
393         u16 csum_start;
394
395         if (skb_cloned(frag) || skb_headroom(frag) < hdr_len) {
396                 int extra_head = hdr_len - skb_headroom(frag);
397
398                 extra_head = extra_head > 0 ? extra_head : 0;
399                 if (unlikely(pskb_expand_head(frag, extra_head, 0,
400                                               GFP_ATOMIC)))
401                         return -ENOMEM;
402         }
403
404         memcpy(__skb_push(frag, hdr_len), head->data, hdr_len);
405
406         csum_start = head->csum_start - skb_headroom(head);
407         frag->csum_start = skb_headroom(frag) + csum_start;
408         frag->csum_offset = head->csum_offset;
409         frag->ip_summed = head->ip_summed;
410
411         skb_shinfo(frag)->gso_size = skb_shinfo(head)->gso_size;
412         skb_shinfo(frag)->gso_type = skb_shinfo(head)->gso_type;
413         skb_shinfo(frag)->gso_segs = 0;
414
415         copy_skb_metadata(frag, head);
416         return 0;
417 }
418
/* Split a frag_list packet into a ->next-linked list of standalone TCP
 * segments, one per frag_list member.  Each member gets a copy of @head's
 * headers, with lengths, sequence numbers and flags fixed up.
 *
 * @head:      packet starting with an Ethernet header.  On success it is
 *             the first segment and head->next links the others.
 * @ipv4:      true for IPv4, false for IPv6.
 * @l4_offset: offset of the TCP header from head->data.
 *
 * Returns 0 on success, -ENOMEM if headers could not be pulled or copied,
 * or -EINVAL for a malformed TCP header.
 */
static int skb_list_segment(struct sk_buff *head, bool ipv4, int l4_offset)
{
	struct sk_buff *skb;
	struct tcphdr *tcph;
	int seg_len;
	int hdr_len;
	int tcp_len;
	u32 seq;

	/* The fixed TCP header must be linear before doff is read. */
	if (unlikely(!pskb_may_pull(head, l4_offset + sizeof(*tcph))))
		return -ENOMEM;

	tcph = (struct tcphdr *)(head->data + l4_offset);
	tcp_len = tcph->doff * 4;
	hdr_len = l4_offset + tcp_len;

	/* Reject a TCP header shorter than the minimum or longer than the
	 * packet.
	 */
	if (unlikely((tcp_len < sizeof(struct tcphdr)) ||
		     (head->len < hdr_len)))
		return -EINVAL;

	/* Make all headers, including TCP options, linear. */
	if (unlikely(!pskb_may_pull(head, hdr_len)))
		return -ENOMEM;

	/* Reload: pskb_may_pull() may have moved the data. */
	tcph = (struct tcphdr *)(head->data + l4_offset);
	/* Update header of each segment. */
	seq = ntohl(tcph->seq);
	/* Payload carried by head itself (linear data plus page frags). */
	seg_len = skb_pagelen(head) - hdr_len;

	/* Detach the frag_list and make it head's ->next chain.  From here on,
	 * an error return leaves every member linked from head->next, so the
	 * caller can free the whole list.
	 */
	skb = skb_shinfo(head)->frag_list;
	skb_shinfo(head)->frag_list = NULL;
	head->next = skb;
	for (; skb; skb = skb->next) {
		int err;

		/* This member is now a separate skb; remove it from head. */
		head->len -= skb->len;
		head->data_len -= skb->len;
		head->truesize -= skb->truesize;

		/* Advance past the previous segment's payload.  This member's
		 * payload is its length before the headers are prepended.
		 */
		seq += seg_len;
		seg_len = skb->len;
		err = copy_headers(head, skb, hdr_len);
		if (err)
			return err;
		update_headers(skb, false, l4_offset, hdr_len, ipv4, seq);
	}
	/* Fix up head last, once its final length is known. */
	update_headers(head, true, l4_offset, hdr_len, ipv4, 0);
	return 0;
}
467
/* Flatten *headp's frag_list tree into a ->next list, then merge adjacent
 * skbs wherever skb_try_coalesce() allows it.  Any skbs that remain become
 * head's frag_list again, and head's len/data_len/truesize are recomputed
 * to include them.
 *
 * *headp may be replaced (see straighten_frag_list()).
 * Returns 0 on success or a negative errno.
 */
static int coalesce_skb(struct sk_buff **headp)
{
	struct sk_buff *frag, *head, *prev;
	int err;

	err = straighten_frag_list(headp);
	if (unlikely(err))
		return err;
	head = *headp;

	/* Coalesce frag list. */
	prev = head;
	for (frag = head->next; frag; frag = frag->next) {
		bool headstolen;
		int delta;

		/* prev must own its data before frag is merged into it. */
		if (unlikely(skb_unclone(prev, GFP_ATOMIC)))
			return -ENOMEM;

		/* Not mergeable: frag becomes the merge target for the next skb. */
		if (!skb_try_coalesce(prev, frag, &headstolen, &delta)) {
			prev = frag;
			continue;
		}

		/* frag's data now belongs to prev.  Unlink frag and release what
		 * remains of it; headstolen says whether its data head moved too.
		 */
		prev->next = frag->next;
		frag->len = 0;
		frag->data_len = 0;
		frag->truesize -= delta;
		kfree_skb_partial(frag, headstolen);
		/* Resume from prev, so the next iteration visits the skb that
		 * followed the merged one.
		 */
		frag = prev;
	}

	if (!head->next)
		return 0;

	/* Re-attach the remaining skbs as head's frag_list and account for them. */
	for (frag = head->next; frag; frag = frag->next) {
		head->len += frag->len;
		head->data_len += frag->len;
		head->truesize += frag->truesize;
	}

	skb_shinfo(head)->frag_list = head->next;
	head->next = NULL;
	return 0;
}
513
514 static int __try_to_segment(struct sk_buff *skb, bool csum_partial,
515                             bool ipv4, bool tcp, int l4_offset)
516 {
517         if (can_segment(skb, ipv4, tcp, csum_partial))
518                 return skb_list_segment(skb, ipv4, l4_offset);
519         else
520                 return skb_linearize(skb);
521 }
522
523 static int try_to_segment(struct sk_buff *skb)
524 {
525         struct stthdr *stth = stt_hdr(skb);
526         bool csum_partial = !!(stth->flags & STT_CSUM_PARTIAL);
527         bool ipv4 = !!(stth->flags & STT_PROTO_IPV4);
528         bool tcp = !!(stth->flags & STT_PROTO_TCP);
529         int l4_offset = stth->l4_offset;
530
531         return __try_to_segment(skb, csum_partial, ipv4, tcp, l4_offset);
532 }
533
534 static int segment_skb(struct sk_buff **headp, bool csum_partial,
535                        bool ipv4, bool tcp, int l4_offset)
536 {
537         int err;
538
539         err = coalesce_skb(headp);
540         if (err)
541                 return err;
542
543         if (skb_shinfo(*headp)->frag_list)
544                 return __try_to_segment(*headp, csum_partial,
545                                         ipv4, tcp, l4_offset);
546         return 0;
547 }
548
549 static int __push_stt_header(struct sk_buff *skb, __be64 tun_id,
550                              __be16 s_port, __be16 d_port,
551                              __be32 saddr, __be32 dst,
552                              __be16 l3_proto, u8 l4_proto,
553                              int dst_mtu)
554 {
555         int data_len = skb->len + sizeof(struct stthdr) + STT_ETH_PAD;
556         unsigned short encap_mss;
557         struct tcphdr *tcph;
558         struct stthdr *stth;
559
560         skb_push(skb, STT_HEADER_LEN);
561         skb_reset_transport_header(skb);
562         tcph = tcp_hdr(skb);
563         memset(tcph, 0, STT_HEADER_LEN);
564         stth = stt_hdr(skb);
565
566         if (skb->ip_summed == CHECKSUM_PARTIAL) {
567                 stth->flags |= STT_CSUM_PARTIAL;
568
569                 stth->l4_offset = skb->csum_start -
570                                         (skb_headroom(skb) +
571                                         STT_HEADER_LEN);
572
573                 if (l3_proto == htons(ETH_P_IP))
574                         stth->flags |= STT_PROTO_IPV4;
575
576                 if (l4_proto == IPPROTO_TCP)
577                         stth->flags |= STT_PROTO_TCP;
578
579                 stth->mss = htons(skb_shinfo(skb)->gso_size);
580         } else if (skb->ip_summed == CHECKSUM_UNNECESSARY) {
581                 stth->flags |= STT_CSUM_VERIFIED;
582         }
583
584         stth->vlan_tci = htons(skb->vlan_tci);
585         skb->vlan_tci = 0;
586         put_unaligned(tun_id, &stth->key);
587
588         tcph->source    = s_port;
589         tcph->dest      = d_port;
590         tcph->doff      = sizeof(struct tcphdr) / 4;
591         tcph->ack       = 1;
592         tcph->psh       = 1;
593         tcph->window    = htons(USHRT_MAX);
594         tcph->seq       = htonl(data_len << STT_SEQ_LEN_SHIFT);
595         tcph->ack_seq   = ack_seq();
596         tcph->check     = ~tcp_v4_check(skb->len, saddr, dst, 0);
597
598         skb->csum_start = skb_transport_header(skb) - skb->head;
599         skb->csum_offset = offsetof(struct tcphdr, check);
600         skb->ip_summed = CHECKSUM_PARTIAL;
601
602         encap_mss = dst_mtu - sizeof(struct iphdr) - sizeof(struct tcphdr);
603         if (data_len > encap_mss) {
604                 if (unlikely(skb_unclone(skb, GFP_ATOMIC)))
605                         return -EINVAL;
606
607                 skb_shinfo(skb)->gso_type = SKB_GSO_TCPV4;
608                 skb_shinfo(skb)->gso_size = encap_mss;
609                 skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(data_len, encap_mss);
610         } else {
611                 if (unlikely(clear_gso(skb)))
612                         return -EINVAL;
613         }
614         return 0;
615 }
616
617 static struct sk_buff *push_stt_header(struct sk_buff *head, __be64 tun_id,
618                                        __be16 s_port, __be16 d_port,
619                                        __be32 saddr, __be32 dst,
620                                        __be16 l3_proto, u8 l4_proto,
621                                        int dst_mtu)
622 {
623         struct sk_buff *skb;
624
625         if (skb_shinfo(head)->frag_list) {
626                 bool ipv4 = (l3_proto == htons(ETH_P_IP));
627                 bool tcp = (l4_proto == IPPROTO_TCP);
628                 bool csum_partial = (head->ip_summed == CHECKSUM_PARTIAL);
629                 int l4_offset = skb_transport_offset(head);
630
631                 /* Need to call skb_orphan() to report currect true-size.
632                  * calling skb_orphan() in this layer is odd but SKB with
633                  * frag-list should not be associated with any socket, so
634                  * skb-orphan should be no-op. */
635                 skb_orphan(head);
636                 if (unlikely(segment_skb(&head, csum_partial,
637                                          ipv4, tcp, l4_offset)))
638                         goto error;
639         }
640
641         for (skb = head; skb; skb = skb->next) {
642                 if (__push_stt_header(skb, tun_id, s_port, d_port, saddr, dst,
643                                       l3_proto, l4_proto, dst_mtu))
644                         goto error;
645         }
646
647         return head;
648 error:
649         kfree_skb_list(head);
650         return NULL;
651 }
652
653 static int stt_can_offload(struct sk_buff *skb, __be16 l3_proto, u8 l4_proto)
654 {
655         if (skb_is_gso(skb) && skb->ip_summed != CHECKSUM_PARTIAL) {
656                 int csum_offset;
657                 __sum16 *csum;
658                 int len;
659
660                 if (l4_proto == IPPROTO_TCP)
661                         csum_offset = offsetof(struct tcphdr, check);
662                 else if (l4_proto == IPPROTO_UDP)
663                         csum_offset = offsetof(struct udphdr, check);
664                 else
665                         return 0;
666
667                 len = skb->len - skb_transport_offset(skb);
668                 csum = (__sum16 *)(skb_transport_header(skb) + csum_offset);
669
670                 if (unlikely(!pskb_may_pull(skb, skb_transport_offset(skb) +
671                                                  csum_offset + sizeof(*csum))))
672                         return -EINVAL;
673
674                 if (l3_proto == htons(ETH_P_IP)) {
675                         struct iphdr *iph = ip_hdr(skb);
676
677                         *csum = ~csum_tcpudp_magic(iph->saddr, iph->daddr,
678                                                    len, l4_proto, 0);
679                 } else if (l3_proto == htons(ETH_P_IPV6)) {
680                         struct ipv6hdr *ip6h = ipv6_hdr(skb);
681
682                         *csum = ~csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
683                                                  len, l4_proto, 0);
684                 } else {
685                         return 0;
686                 }
687                 skb->csum_start = skb_transport_header(skb) - skb->head;
688                 skb->csum_offset = csum_offset;
689                 skb->ip_summed = CHECKSUM_PARTIAL;
690         }
691
692         if (skb->ip_summed == CHECKSUM_PARTIAL) {
693                 /* Assume receiver can only offload TCP/UDP over IPv4/6,
694                  * and require 802.1Q VLANs to be accelerated.
695                  */
696                 if (l3_proto != htons(ETH_P_IP) &&
697                     l3_proto != htons(ETH_P_IPV6))
698                         return 0;
699
700                 if (l4_proto != IPPROTO_TCP && l4_proto != IPPROTO_UDP)
701                         return 0;
702
703                 /* L4 offset must fit in a 1-byte field. */
704                 if (skb->csum_start - skb_headroom(skb) > 255)
705                         return 0;
706
707                 if (skb_shinfo(skb)->gso_type & ~SUPPORTED_GSO_TYPES)
708                         return 0;
709         }
710         /* Total size of encapsulated packet must fit in 16 bits. */
711         if (skb->len + STT_HEADER_LEN + sizeof(struct iphdr) > 65535)
712                 return 0;
713
714 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
715         if (skb_vlan_tag_present(skb) && skb->vlan_proto != htons(ETH_P_8021Q))
716                 return 0;
717 #endif
718         return 1;
719 }
720
721 static bool need_linearize(const struct sk_buff *skb)
722 {
723         struct skb_shared_info *shinfo = skb_shinfo(skb);
724         int i;
725
726         if (unlikely(shinfo->frag_list))
727                 return true;
728
729         /* Generally speaking we should linearize if there are paged frags.
730          * However, if all of the refcounts are 1 we know nobody else can
731          * change them from underneath us and we can skip the linearization.
732          */
733         for (i = 0; i < shinfo->nr_frags; i++)
734                 if (unlikely(page_count(skb_frag_page(&shinfo->frags[i])) > 1))
735                         return true;
736
737         return false;
738 }
739
/* Resolve in software the offloads of a packet that stt_can_offload()
 * refused, so the result needs no further checksum or GSO work.
 *
 * - On kernels >= 3.10, a VLAN tag whose protocol is not 802.1Q is pushed
 *   into the packet data.  Headroom is grown first if needed.
 * - GSO packets are segmented in software, and the result may be a ->next
 *   list.  skb->cb is saved and restored on every segment.
 * - Other CHECKSUM_PARTIAL packets get their checksum computed in software.
 *   They are linearized first if their pages could change underneath us.
 *
 * @min_headroom: headroom the caller needs in front of the packet; it is
 *                only checked here when a VLAN tag has to be pushed.
 *
 * Returns the (possibly new) skb or skb list, or an ERR_PTR() on error, in
 * which case the packet has been released.
 * NOTE: ip_summed is reset to CHECKSUM_NONE only on the returned skb (the
 * first one of a segment list).
 */
static struct sk_buff *handle_offloads(struct sk_buff *skb, int min_headroom)
{
	int err;

#if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
	if (skb_vlan_tag_present(skb) && skb->vlan_proto != htons(ETH_P_8021Q)) {

		min_headroom += VLAN_HLEN;
		if (skb_headroom(skb) < min_headroom) {
			int head_delta = SKB_DATA_ALIGN(min_headroom -
							skb_headroom(skb) + 16);

			err = pskb_expand_head(skb, max_t(int, head_delta, 0),
					       0, GFP_ATOMIC);
			if (unlikely(err))
				goto error;
		}

		skb = __vlan_hwaccel_push_inside(skb);
		if (!skb) {
			err = -ENOMEM;
			goto error;
		}
	}
#endif

	if (skb_is_gso(skb)) {
		struct sk_buff *nskb;
		char cb[sizeof(skb->cb)];

		/* Segmentation may clobber cb; restore it on each segment. */
		memcpy(cb, skb->cb, sizeof(cb));

		/* No offload features: segments get software checksums. */
		nskb = __skb_gso_segment(skb, 0, false);
		if (IS_ERR(nskb)) {
			err = PTR_ERR(nskb);
			goto error;
		}

		consume_skb(skb);
		skb = nskb;
		while (nskb) {
			memcpy(nskb->cb, cb, sizeof(cb));
			nskb = nskb->next;
		}
	} else if (skb->ip_summed == CHECKSUM_PARTIAL) {
		/* Pages aren't locked and could change at any time.
		 * If this happens after we compute the checksum, the
		 * checksum will be wrong.  We linearize now to avoid
		 * this problem.
		 */
		if (unlikely(need_linearize(skb))) {
			err = __skb_linearize(skb);
			if (unlikely(err))
				goto error;
		}

		err = skb_checksum_help(skb);
		if (unlikely(err))
			goto error;
	}
	skb->ip_summed = CHECKSUM_NONE;

	return skb;
error:
	kfree_skb(skb);
	return ERR_PTR(err);
}
807
808 static int skb_list_xmit(struct rtable *rt, struct sk_buff *skb, __be32 src,
809                          __be32 dst, __u8 tos, __u8 ttl, __be16 df)
810 {
811         int len = 0;
812
813         while (skb) {
814                 struct sk_buff *next = skb->next;
815
816                 if (next)
817                         dst_clone(&rt->dst);
818
819                 skb->next = NULL;
820                 len += iptunnel_xmit(NULL, rt, skb, src, dst, IPPROTO_TCP,
821                                      tos, ttl, df, false);
822
823                 skb = next;
824         }
825         return len;
826 }
827
828 static u8 parse_ipv6_l4_proto(struct sk_buff *skb)
829 {
830         unsigned int nh_ofs = skb_network_offset(skb);
831         int payload_ofs;
832         struct ipv6hdr *nh;
833         uint8_t nexthdr;
834         __be16 frag_off;
835
836         if (unlikely(!pskb_may_pull(skb, nh_ofs + sizeof(struct ipv6hdr))))
837                 return 0;
838
839         nh = ipv6_hdr(skb);
840         nexthdr = nh->nexthdr;
841         payload_ofs = (u8 *)(nh + 1) - skb->data;
842
843         payload_ofs = ipv6_skip_exthdr(skb, payload_ofs, &nexthdr, &frag_off);
844         if (unlikely(payload_ofs < 0))
845                 return 0;
846
847         return nexthdr;
848 }
849
850 static u8 skb_get_l4_proto(struct sk_buff *skb, __be16 l3_proto)
851 {
852         if (l3_proto == htons(ETH_P_IP)) {
853                 unsigned int nh_ofs = skb_network_offset(skb);
854
855                 if (unlikely(!pskb_may_pull(skb, nh_ofs + sizeof(struct iphdr))))
856                         return 0;
857
858                 return ip_hdr(skb)->protocol;
859         } else if (l3_proto == htons(ETH_P_IPV6)) {
860                 return parse_ipv6_l4_proto(skb);
861         }
862         return 0;
863 }
864
/* Encapsulate @skb (an Ethernet frame) in STT and transmit it over @rt.
 *
 * @rt:        route to the tunnel endpoint.  The reference passed in is
 *             consumed: it is dropped on error, otherwise handed (with extra
 *             clones per skb) to the transmit path.
 * @src, @dst: outer IPv4 addresses.
 * @tos, @ttl, @df: outer IP header fields.
 * @src_port, @dst_port: outer TCP ports.
 * @tun_id:    tunnel key written into the STT header.
 *
 * Offloads are passed to the receiver when stt_can_offload() allows it;
 * otherwise they are resolved in software by handle_offloads().
 *
 * Returns the sum of the iptunnel_xmit() results, or a negative errno if
 * the packet was dropped before the per-skb transmit loop.  @skb is always
 * consumed.
 */
static int stt_xmit_skb(struct sk_buff *skb, struct rtable *rt,
		 __be32 src, __be32 dst, __u8 tos,
		 __u8 ttl, __be16 df, __be16 src_port, __be16 dst_port,
		 __be64 tun_id)
{
	struct ethhdr *eh = eth_hdr(skb);
	int ret = 0, min_headroom;
	__be16 inner_l3_proto;
	u8 inner_l4_proto;

	inner_l3_proto = eh->h_proto;
	inner_l4_proto = skb_get_l4_proto(skb, inner_l3_proto);

	/* Room for the link layer, the outer IPv4 header and the STT/TCP header. */
	min_headroom = LL_RESERVED_SPACE(rt->dst.dev) + rt->dst.header_len
			+ STT_HEADER_LEN + sizeof(struct iphdr);

	/* Headers are written in place, so they must be private and have
	 * enough headroom.
	 */
	if (skb_headroom(skb) < min_headroom || skb_header_cloned(skb)) {
		int head_delta = SKB_DATA_ALIGN(min_headroom -
						skb_headroom(skb) +
						16);

		ret = pskb_expand_head(skb, max_t(int, head_delta, 0),
				       0, GFP_ATOMIC);
		if (unlikely(ret))
			goto err_free_rt;
	}

	ret = stt_can_offload(skb, inner_l3_proto, inner_l4_proto);
	if (ret < 0)
		goto err_free_rt;
	if (!ret) {
		skb = handle_offloads(skb, min_headroom);
		if (IS_ERR(skb)) {
			ret = PTR_ERR(skb);
			/* handle_offloads() already released the packet. */
			skb = NULL;
			goto err_free_rt;
		}
	}

	ret = 0;
	while (skb) {
		struct sk_buff *next_skb = skb->next;

		skb->next = NULL;

		/* Each skb in the list consumes one route reference. */
		if (next_skb)
			dst_clone(&rt->dst);

		/* Push STT and TCP header. */
		skb = push_stt_header(skb, tun_id, src_port, dst_port, src,
				      dst, inner_l3_proto, inner_l4_proto,
				      dst_mtu(&rt->dst));
		if (unlikely(!skb)) {
			/* push_stt_header() freed the skb; drop its route ref. */
			ip_rt_put(rt);
			goto next;
		}

		/* Push IP header. */
		ret += skb_list_xmit(rt, skb, src, dst, tos, ttl, df);

next:
		skb = next_skb;
	}

	return ret;

err_free_rt:
	ip_rt_put(rt);
	kfree_skb(skb);
	return ret;
}
936
937 netdev_tx_t ovs_stt_xmit(struct sk_buff *skb)
938 {
939         struct net_device *dev = skb->dev;
940         struct stt_dev *stt_dev = netdev_priv(dev);
941         struct net *net = stt_dev->net;
942         __be16 dport = stt_dev->dst_port;
943         struct ip_tunnel_key *tun_key;
944         struct ip_tunnel_info *tun_info;
945         struct rtable *rt;
946         struct flowi4 fl;
947         __be16 sport;
948         __be16 df;
949         int err;
950
951         tun_info = skb_tunnel_info(skb);
952         if (unlikely(!tun_info)) {
953                 err = -EINVAL;
954                 goto error;
955         }
956
957         tun_key = &tun_info->key;
958
959         /* Route lookup */
960         memset(&fl, 0, sizeof(fl));
961         fl.daddr = tun_key->u.ipv4.dst;
962         fl.saddr = tun_key->u.ipv4.src;
963         fl.flowi4_tos = RT_TOS(tun_key->tos);
964         fl.flowi4_mark = skb->mark;
965         fl.flowi4_proto = IPPROTO_TCP;
966         rt = ip_route_output_key(net, &fl);
967         if (IS_ERR(rt)) {
968                 err = PTR_ERR(rt);
969                 goto error;
970         }
971
972         df = tun_key->tun_flags & TUNNEL_DONT_FRAGMENT ? htons(IP_DF) : 0;
973         sport = udp_flow_src_port(net, skb, 1, USHRT_MAX, true);
974         skb->ignore_df = 1;
975
976         err = stt_xmit_skb(skb, rt, fl.saddr, tun_key->u.ipv4.dst,
977                             tun_key->tos, tun_key->ttl,
978                             df, sport, dport, tun_key->tun_id);
979         iptunnel_xmit_stats(err, &dev->stats, (struct pcpu_sw_netstats __percpu *)dev->tstats);
980         return NETDEV_TX_OK;
981 error:
982         kfree_skb(skb);
983         dev->stats.tx_errors++;
984         return NETDEV_TX_OK;
985 }
986 EXPORT_SYMBOL(ovs_stt_xmit);
987
988 static void free_frag(struct stt_percpu *stt_percpu,
989                       struct pkt_frag *frag)
990 {
991         stt_percpu->frag_mem_used -= FRAG_CB(frag->skbs)->first.mem_used;
992         kfree_skb_list(frag->skbs);
993         list_del(&frag->lru_node);
994         frag->skbs = NULL;
995 }
996
997 static void evict_frags(struct stt_percpu *stt_percpu)
998 {
999         while (!list_empty(&stt_percpu->frag_lru) &&
1000                stt_percpu->frag_mem_used > REASM_LO_THRESH) {
1001                 struct pkt_frag *frag;
1002
1003                 frag = list_first_entry(&stt_percpu->frag_lru,
1004                                         struct pkt_frag,
1005                                         lru_node);
1006                 free_frag(stt_percpu, frag);
1007         }
1008 }
1009
1010 static bool pkt_key_match(struct net *net,
1011                           const struct pkt_frag *a, const struct pkt_key *b)
1012 {
1013         return a->key.saddr == b->saddr && a->key.daddr == b->daddr &&
1014                a->key.pkt_seq == b->pkt_seq && a->key.mark == b->mark &&
1015                net_eq(dev_net(a->skbs->dev), net);
1016 }
1017
1018 static u32 pkt_key_hash(const struct net *net, const struct pkt_key *key)
1019 {
1020         u32 initval = frag_hash_seed ^ (u32)(unsigned long)net ^ key->mark;
1021
1022         return jhash_3words((__force u32)key->saddr, (__force u32)key->daddr,
1023                             (__force u32)key->pkt_seq, initval);
1024 }
1025
1026 static struct pkt_frag *lookup_frag(struct net *net,
1027                                     struct stt_percpu *stt_percpu,
1028                                     const struct pkt_key *key, u32 hash)
1029 {
1030         struct pkt_frag *frag, *victim_frag = NULL;
1031         int i;
1032
1033         for (i = 0; i < FRAG_HASH_SEGS; i++) {
1034                 frag = flex_array_get(stt_percpu->frag_hash,
1035                                       hash & (FRAG_HASH_ENTRIES - 1));
1036
1037                 if (frag->skbs &&
1038                     time_before(jiffies, frag->timestamp + FRAG_EXP_TIME) &&
1039                     pkt_key_match(net, frag, key))
1040                         return frag;
1041
1042                 if (!victim_frag ||
1043                     (victim_frag->skbs &&
1044                      (!frag->skbs ||
1045                       time_before(frag->timestamp, victim_frag->timestamp))))
1046                         victim_frag = frag;
1047
1048                 hash >>= FRAG_HASH_SHIFT;
1049         }
1050
1051         if (victim_frag->skbs)
1052                 free_frag(stt_percpu, victim_frag);
1053
1054         return victim_frag;
1055 }
1056
1057 static struct sk_buff *reassemble(struct sk_buff *skb)
1058 {
1059         struct iphdr *iph = ip_hdr(skb);
1060         struct tcphdr *tcph = tcp_hdr(skb);
1061         u32 seq = ntohl(tcph->seq);
1062         struct stt_percpu *stt_percpu;
1063         struct sk_buff *last_skb;
1064         struct pkt_frag *frag;
1065         struct pkt_key key;
1066         int tot_len;
1067         u32 hash;
1068
1069         tot_len = seq >> STT_SEQ_LEN_SHIFT;
1070         FRAG_CB(skb)->offset = seq & STT_SEQ_OFFSET_MASK;
1071
1072         if (unlikely(skb->len == 0))
1073                 goto out_free;
1074
1075         if (unlikely(FRAG_CB(skb)->offset + skb->len > tot_len))
1076                 goto out_free;
1077
1078         if (tot_len == skb->len)
1079                 goto out;
1080
1081         key.saddr = iph->saddr;
1082         key.daddr = iph->daddr;
1083         key.pkt_seq = tcph->ack_seq;
1084         key.mark = skb->mark;
1085         hash = pkt_key_hash(dev_net(skb->dev), &key);
1086
1087         stt_percpu = per_cpu_ptr(stt_percpu_data, smp_processor_id());
1088
1089         spin_lock(&stt_percpu->lock);
1090
1091         if (unlikely(stt_percpu->frag_mem_used + skb->truesize > REASM_HI_THRESH))
1092                 evict_frags(stt_percpu);
1093
1094         frag = lookup_frag(dev_net(skb->dev), stt_percpu, &key, hash);
1095         if (!frag->skbs) {
1096                 frag->skbs = skb;
1097                 frag->key = key;
1098                 frag->timestamp = jiffies;
1099                 FRAG_CB(skb)->first.last_skb = skb;
1100                 FRAG_CB(skb)->first.mem_used = skb->truesize;
1101                 FRAG_CB(skb)->first.tot_len = tot_len;
1102                 FRAG_CB(skb)->first.rcvd_len = skb->len;
1103                 FRAG_CB(skb)->first.set_ecn_ce = false;
1104                 list_add_tail(&frag->lru_node, &stt_percpu->frag_lru);
1105                 stt_percpu->frag_mem_used += skb->truesize;
1106
1107                 skb = NULL;
1108                 goto unlock;
1109         }
1110
1111         /* Optimize for the common case where fragments are received in-order
1112          * and not overlapping.
1113          */
1114         last_skb = FRAG_CB(frag->skbs)->first.last_skb;
1115         if (likely(FRAG_CB(last_skb)->offset + last_skb->len ==
1116                    FRAG_CB(skb)->offset)) {
1117                 last_skb->next = skb;
1118                 FRAG_CB(frag->skbs)->first.last_skb = skb;
1119         } else {
1120                 struct sk_buff *prev = NULL, *next;
1121
1122                 for (next = frag->skbs; next; next = next->next) {
1123                         if (FRAG_CB(next)->offset >= FRAG_CB(skb)->offset)
1124                                 break;
1125                         prev = next;
1126                 }
1127
1128                 /* Overlapping fragments aren't allowed.  We shouldn't start
1129                  * before the end of the previous fragment.
1130                  */
1131                 if (prev &&
1132                     FRAG_CB(prev)->offset + prev->len > FRAG_CB(skb)->offset)
1133                         goto unlock_free;
1134
1135                 /* We also shouldn't end after the beginning of the next
1136                  * fragment.
1137                  */
1138                 if (next &&
1139                     FRAG_CB(skb)->offset + skb->len > FRAG_CB(next)->offset)
1140                         goto unlock_free;
1141
1142                 if (prev) {
1143                         prev->next = skb;
1144                 } else {
1145                         FRAG_CB(skb)->first = FRAG_CB(frag->skbs)->first;
1146                         frag->skbs = skb;
1147                 }
1148
1149                 if (next)
1150                         skb->next = next;
1151                 else
1152                         FRAG_CB(frag->skbs)->first.last_skb = skb;
1153         }
1154
1155         FRAG_CB(frag->skbs)->first.set_ecn_ce |= INET_ECN_is_ce(iph->tos);
1156         FRAG_CB(frag->skbs)->first.rcvd_len += skb->len;
1157         FRAG_CB(frag->skbs)->first.mem_used += skb->truesize;
1158         stt_percpu->frag_mem_used += skb->truesize;
1159
1160         if (FRAG_CB(frag->skbs)->first.tot_len ==
1161             FRAG_CB(frag->skbs)->first.rcvd_len) {
1162                 struct sk_buff *frag_head = frag->skbs;
1163
1164                 frag_head->tstamp = skb->tstamp;
1165                 if (FRAG_CB(frag_head)->first.set_ecn_ce)
1166                         INET_ECN_set_ce(frag_head);
1167
1168                 list_del(&frag->lru_node);
1169                 stt_percpu->frag_mem_used -= FRAG_CB(frag_head)->first.mem_used;
1170                 frag->skbs = NULL;
1171                 skb = frag_head;
1172         } else {
1173                 list_move_tail(&frag->lru_node, &stt_percpu->frag_lru);
1174                 skb = NULL;
1175         }
1176
1177         goto unlock;
1178
1179 unlock_free:
1180         kfree_skb(skb);
1181         skb = NULL;
1182 unlock:
1183         spin_unlock(&stt_percpu->lock);
1184         return skb;
1185 out_free:
1186         kfree_skb(skb);
1187         skb = NULL;
1188 out:
1189         return skb;
1190 }
1191
1192 static bool validate_checksum(struct sk_buff *skb)
1193 {
1194         struct iphdr *iph = ip_hdr(skb);
1195
1196         if (skb_csum_unnecessary(skb))
1197                 return true;
1198
1199         if (skb->ip_summed == CHECKSUM_COMPLETE &&
1200             !tcp_v4_check(skb->len, iph->saddr, iph->daddr, skb->csum))
1201                 return true;
1202
1203         skb->csum = csum_tcpudp_nofold(iph->saddr, iph->daddr, skb->len,
1204                                        IPPROTO_TCP, 0);
1205
1206         return __tcp_checksum_complete(skb) == 0;
1207 }
1208
1209 static bool set_offloads(struct sk_buff *skb)
1210 {
1211         struct stthdr *stth = stt_hdr(skb);
1212         unsigned short gso_type;
1213         int l3_header_size;
1214         int l4_header_size;
1215         u16 csum_offset;
1216         u8 proto_type;
1217
1218         if (stth->vlan_tci)
1219                 __vlan_hwaccel_put_tag(skb, htons(ETH_P_8021Q),
1220                                        ntohs(stth->vlan_tci));
1221
1222         if (!(stth->flags & STT_CSUM_PARTIAL)) {
1223                 if (stth->flags & STT_CSUM_VERIFIED)
1224                         skb->ip_summed = CHECKSUM_UNNECESSARY;
1225                 else
1226                         skb->ip_summed = CHECKSUM_NONE;
1227
1228                 return clear_gso(skb) == 0;
1229         }
1230
1231         proto_type = stth->flags & STT_PROTO_TYPES;
1232
1233         switch (proto_type) {
1234         case (STT_PROTO_IPV4 | STT_PROTO_TCP):
1235                 /* TCP/IPv4 */
1236                 csum_offset = offsetof(struct tcphdr, check);
1237                 gso_type = SKB_GSO_TCPV4;
1238                 l3_header_size = sizeof(struct iphdr);
1239                 l4_header_size = sizeof(struct tcphdr);
1240                 skb->protocol = htons(ETH_P_IP);
1241                 break;
1242         case STT_PROTO_TCP:
1243                 /* TCP/IPv6 */
1244                 csum_offset = offsetof(struct tcphdr, check);
1245                 gso_type = SKB_GSO_TCPV6;
1246                 l3_header_size = sizeof(struct ipv6hdr);
1247                 l4_header_size = sizeof(struct tcphdr);
1248                 skb->protocol = htons(ETH_P_IPV6);
1249                 break;
1250         case STT_PROTO_IPV4:
1251                 /* UDP/IPv4 */
1252                 csum_offset = offsetof(struct udphdr, check);
1253                 gso_type = SKB_GSO_UDP;
1254                 l3_header_size = sizeof(struct iphdr);
1255                 l4_header_size = sizeof(struct udphdr);
1256                 skb->protocol = htons(ETH_P_IP);
1257                 break;
1258         default:
1259                 /* UDP/IPv6 */
1260                 csum_offset = offsetof(struct udphdr, check);
1261                 gso_type = SKB_GSO_UDP;
1262                 l3_header_size = sizeof(struct ipv6hdr);
1263                 l4_header_size = sizeof(struct udphdr);
1264                 skb->protocol = htons(ETH_P_IPV6);
1265         }
1266
1267         if (unlikely(stth->l4_offset < ETH_HLEN + l3_header_size))
1268                 return false;
1269
1270         if (unlikely(!pskb_may_pull(skb, stth->l4_offset + l4_header_size)))
1271                 return false;
1272
1273         stth = stt_hdr(skb);
1274
1275         skb->csum_start = skb_headroom(skb) + stth->l4_offset;
1276         skb->csum_offset = csum_offset;
1277         skb->ip_summed = CHECKSUM_PARTIAL;
1278
1279         if (stth->mss) {
1280                 if (unlikely(skb_unclone(skb, GFP_ATOMIC)))
1281                         return false;
1282
1283                 skb_shinfo(skb)->gso_type = gso_type | SKB_GSO_DODGY;
1284                 skb_shinfo(skb)->gso_size = ntohs(stth->mss);
1285                 skb_shinfo(skb)->gso_segs = 0;
1286         } else {
1287                 if (unlikely(clear_gso(skb)))
1288                         return false;
1289         }
1290
1291         return true;
1292 }
1293
1294 static void rcv_list(struct net_device *dev, struct sk_buff *skb,
1295                      struct metadata_dst *tun_dst)
1296 {
1297         struct sk_buff *next;
1298
1299         do {
1300                 next = skb->next;
1301                 skb->next = NULL;
1302                 if (next) {
1303                         ovs_dst_hold((struct dst_entry *)tun_dst);
1304                         ovs_skb_dst_set(next, (struct dst_entry *)tun_dst);
1305                 }
1306                 ovs_ip_tunnel_rcv(dev, skb, tun_dst);
1307         } while ((skb = next));
1308 }
1309
1310 #ifndef HAVE_METADATA_DST
1311 static int __stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
1312 {
1313         struct metadata_dst tun_dst;
1314
1315         ovs_ip_tun_rx_dst(&tun_dst.u.tun_info, skb, TUNNEL_KEY | TUNNEL_CSUM,
1316                           get_unaligned(&stt_hdr(skb)->key), 0);
1317         tun_dst.u.tun_info.key.tp_src = tcp_hdr(skb)->source;
1318         tun_dst.u.tun_info.key.tp_dst = tcp_hdr(skb)->dest;
1319
1320         rcv_list(stt_dev->dev, skb, &tun_dst);
1321         return 0;
1322 }
1323 #else
1324 static int __stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
1325 {
1326         struct metadata_dst *tun_dst;
1327         __be16 flags;
1328         __be64 tun_id;
1329
1330         flags = TUNNEL_KEY | TUNNEL_CSUM;
1331         tun_id = get_unaligned(&stt_hdr(skb)->key);
1332         tun_dst = ip_tun_rx_dst(skb, flags, tun_id, 0);
1333         if (!tun_dst)
1334                 return -ENOMEM;
1335         tun_dst->u.tun_info.key.tp_src = tcp_hdr(skb)->source;
1336         tun_dst->u.tun_info.key.tp_dst = tcp_hdr(skb)->dest;
1337
1338         rcv_list(stt_dev->dev, skb, tun_dst);
1339         return 0;
1340 }
1341 #endif
1342
1343 static void stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
1344 {
1345         int err;
1346
1347         if (unlikely(!validate_checksum(skb)))
1348                 goto drop;
1349
1350         __skb_pull(skb, sizeof(struct tcphdr));
1351         skb = reassemble(skb);
1352         if (!skb)
1353                 return;
1354
1355         if (skb->next && coalesce_skb(&skb))
1356                 goto drop;
1357
1358         err = iptunnel_pull_header(skb,
1359                                    sizeof(struct stthdr) + STT_ETH_PAD,
1360                                    htons(ETH_P_TEB));
1361         if (unlikely(err))
1362                 goto drop;
1363
1364         if (unlikely(stt_hdr(skb)->version != 0))
1365                 goto drop;
1366
1367         if (unlikely(!set_offloads(skb)))
1368                 goto drop;
1369
1370         if (skb_shinfo(skb)->frag_list && try_to_segment(skb))
1371                 goto drop;
1372
1373         err = __stt_rcv(stt_dev, skb);
1374         if (err)
1375                 goto drop;
1376         return;
1377 drop:
1378         /* Consume bad packet */
1379         kfree_skb_list(skb);
1380         stt_dev->dev->stats.rx_errors++;
1381 }
1382
1383 static void tcp_sock_release(struct socket *sock)
1384 {
1385         kernel_sock_shutdown(sock, SHUT_RDWR);
1386         sock_release(sock);
1387 }
1388
1389 static int tcp_sock_create4(struct net *net, __be16 port,
1390                             struct socket **sockp)
1391 {
1392         struct sockaddr_in tcp_addr;
1393         struct socket *sock = NULL;
1394         int err;
1395
1396         err = sock_create_kern(net, AF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
1397         if (err < 0)
1398                 goto error;
1399
1400         memset(&tcp_addr, 0, sizeof(tcp_addr));
1401         tcp_addr.sin_family = AF_INET;
1402         tcp_addr.sin_addr.s_addr = htonl(INADDR_ANY);
1403         tcp_addr.sin_port = port;
1404         err = kernel_bind(sock, (struct sockaddr *)&tcp_addr,
1405                           sizeof(tcp_addr));
1406         if (err < 0)
1407                 goto error;
1408
1409         *sockp = sock;
1410         return 0;
1411
1412 error:
1413         if (sock)
1414                 tcp_sock_release(sock);
1415         *sockp = NULL;
1416         return err;
1417 }
1418
1419 static void schedule_clean_percpu(void)
1420 {
1421         schedule_delayed_work(&clean_percpu_wq, CLEAN_PERCPU_INTERVAL);
1422 }
1423
1424 static void clean_percpu(struct work_struct *work)
1425 {
1426         int i;
1427
1428         for_each_possible_cpu(i) {
1429                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1430                 int j;
1431
1432                 for (j = 0; j < FRAG_HASH_ENTRIES; j++) {
1433                         struct pkt_frag *frag;
1434
1435                         frag = flex_array_get(stt_percpu->frag_hash, j);
1436                         if (!frag->skbs ||
1437                             time_before(jiffies, frag->timestamp + FRAG_EXP_TIME))
1438                                 continue;
1439
1440                         spin_lock_bh(&stt_percpu->lock);
1441
1442                         if (frag->skbs &&
1443                             time_after(jiffies, frag->timestamp + FRAG_EXP_TIME))
1444                                 free_frag(stt_percpu, frag);
1445
1446                         spin_unlock_bh(&stt_percpu->lock);
1447                 }
1448         }
1449         schedule_clean_percpu();
1450 }
1451
/* The netfilter hook prototype differs between kernel releases.
 * FIRST_PARAM and LAST_PARAM expand to the leading and trailing parameters
 * this kernel's hook function takes, so nf_ip_hook() is written only once.
 */
#ifndef HAVE_NF_HOOKFN_ARG_OPS
#define FIRST_PARAM unsigned int hooknum
#else
#define FIRST_PARAM const struct nf_hook_ops *ops
#endif

#if !defined(HAVE_NF_HOOK_STATE)
#define LAST_PARAM const struct net_device *in, const struct net_device *out, \
		   int (*okfn)(struct sk_buff *)
#elif RHEL_RELEASE_CODE <= RHEL_RELEASE_VERSION(7,0)
#define LAST_PARAM const struct nf_hook_state *state
#elif !defined(__GENKSYMS__)
/* RHEL nfhook hacks (RHEL newer than 7.0). */
#define LAST_PARAM const struct net_device *in, const struct net_device *out, \
		   const struct nf_hook_state *state
#else
/* RHEL nfhook hacks, genksyms view (RHEL newer than 7.0). */
#define LAST_PARAM const struct net_device *in, const struct net_device *out, \
		   int (*okfn)(struct sk_buff *)
#endif
1475
1476 static unsigned int nf_ip_hook(FIRST_PARAM, struct sk_buff *skb, LAST_PARAM)
1477 {
1478         struct stt_dev *stt_dev;
1479         int ip_hdr_len;
1480
1481         if (ip_hdr(skb)->protocol != IPPROTO_TCP)
1482                 return NF_ACCEPT;
1483
1484         ip_hdr_len = ip_hdrlen(skb);
1485         if (unlikely(!pskb_may_pull(skb, ip_hdr_len + sizeof(struct tcphdr))))
1486                 return NF_ACCEPT;
1487
1488         skb_set_transport_header(skb, ip_hdr_len);
1489
1490         stt_dev = stt_find_up_dev(dev_net(skb->dev), tcp_hdr(skb)->dest);
1491         if (!stt_dev)
1492                 return NF_ACCEPT;
1493
1494         __skb_pull(skb, ip_hdr_len);
1495         stt_rcv(stt_dev, skb);
1496         return NF_STOLEN;
1497 }
1498
1499 static struct nf_hook_ops nf_hook_ops __read_mostly = {
1500         .hook           = nf_ip_hook,
1501         .owner          = THIS_MODULE,
1502         .pf             = NFPROTO_IPV4,
1503         .hooknum        = NF_INET_LOCAL_IN,
1504         .priority       = INT_MAX,
1505 };
1506
1507 static int stt_start(struct net *net)
1508 {
1509         struct stt_net *sn = net_generic(net, stt_net_id);
1510         int err;
1511         int i;
1512
1513         if (n_tunnels) {
1514                 n_tunnels++;
1515                 return 0;
1516         }
1517         get_random_bytes(&frag_hash_seed, sizeof(u32));
1518
1519         stt_percpu_data = alloc_percpu(struct stt_percpu);
1520         if (!stt_percpu_data) {
1521                 err = -ENOMEM;
1522                 goto error;
1523         }
1524
1525         for_each_possible_cpu(i) {
1526                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1527                 struct flex_array *frag_hash;
1528
1529                 spin_lock_init(&stt_percpu->lock);
1530                 INIT_LIST_HEAD(&stt_percpu->frag_lru);
1531                 get_random_bytes(&per_cpu(pkt_seq_counter, i), sizeof(u32));
1532
1533                 frag_hash = flex_array_alloc(sizeof(struct pkt_frag),
1534                                              FRAG_HASH_ENTRIES,
1535                                              GFP_KERNEL | __GFP_ZERO);
1536                 if (!frag_hash) {
1537                         err = -ENOMEM;
1538                         goto free_percpu;
1539                 }
1540                 stt_percpu->frag_hash = frag_hash;
1541
1542                 err = flex_array_prealloc(stt_percpu->frag_hash, 0,
1543                                           FRAG_HASH_ENTRIES,
1544                                           GFP_KERNEL | __GFP_ZERO);
1545                 if (err)
1546                         goto free_percpu;
1547         }
1548         schedule_clean_percpu();
1549         n_tunnels++;
1550
1551         if (sn->n_tunnels) {
1552                 sn->n_tunnels++;
1553                 return 0;
1554         }
1555 #ifdef HAVE_NF_REGISTER_NET_HOOK
1556         /* On kernel which support per net nf-hook, nf_register_hook() takes
1557          * rtnl-lock, which results in dead lock in stt-dev-create. Therefore
1558          * use this new API.
1559          */
1560
1561         if (sn->nf_hook_reg_done)
1562                 goto out;
1563
1564         err = nf_register_net_hook(net, &nf_hook_ops);
1565         if (!err)
1566                 sn->nf_hook_reg_done = true;
1567 #else
1568         /* Register STT only on very first STT device addition. */
1569         if (!list_empty(&nf_hook_ops.list))
1570                 goto out;
1571
1572         err = nf_register_hook(&nf_hook_ops);
1573 #endif
1574         if (err)
1575                 goto dec_n_tunnel;
1576 out:
1577         sn->n_tunnels++;
1578         return 0;
1579
1580 dec_n_tunnel:
1581         n_tunnels--;
1582 free_percpu:
1583         for_each_possible_cpu(i) {
1584                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1585
1586                 if (stt_percpu->frag_hash)
1587                         flex_array_free(stt_percpu->frag_hash);
1588         }
1589
1590         free_percpu(stt_percpu_data);
1591
1592 error:
1593         return err;
1594 }
1595
1596 static void stt_cleanup(struct net *net)
1597 {
1598         struct stt_net *sn = net_generic(net, stt_net_id);
1599         int i;
1600
1601         sn->n_tunnels--;
1602         if (sn->n_tunnels)
1603                 goto out;
1604 out:
1605         n_tunnels--;
1606         if (n_tunnels)
1607                 return;
1608
1609         cancel_delayed_work_sync(&clean_percpu_wq);
1610         for_each_possible_cpu(i) {
1611                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1612                 int j;
1613
1614                 for (j = 0; j < FRAG_HASH_ENTRIES; j++) {
1615                         struct pkt_frag *frag;
1616
1617                         frag = flex_array_get(stt_percpu->frag_hash, j);
1618                         kfree_skb_list(frag->skbs);
1619                 }
1620
1621                 flex_array_free(stt_percpu->frag_hash);
1622         }
1623
1624         free_percpu(stt_percpu_data);
1625 }
1626
1627 static netdev_tx_t stt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1628 {
1629 #ifdef HAVE_METADATA_DST
1630         return ovs_stt_xmit(skb);
1631 #else
1632         /* Drop All packets coming from networking stack. OVS-CB is
1633          * not initialized for these packets.
1634          */
1635         dev_kfree_skb(skb);
1636         dev->stats.tx_dropped++;
1637         return NETDEV_TX_OK;
1638 #endif
1639 }
1640
1641 /* Setup stats when device is created */
1642 static int stt_init(struct net_device *dev)
1643 {
1644         dev->tstats = (typeof(dev->tstats)) netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
1645         if (!dev->tstats)
1646                 return -ENOMEM;
1647
1648         return 0;
1649 }
1650
1651 static void stt_uninit(struct net_device *dev)
1652 {
1653         free_percpu(dev->tstats);
1654 }
1655
1656 static int stt_open(struct net_device *dev)
1657 {
1658         struct stt_dev *stt = netdev_priv(dev);
1659         struct net *net = stt->net;
1660         struct stt_net *sn = net_generic(net, stt_net_id);
1661         int err;
1662
1663         err = stt_start(net);
1664         if (err)
1665                 return err;
1666
1667         err = tcp_sock_create4(net, stt->dst_port, &stt->sock);
1668         if (err)
1669                 return err;
1670         list_add_rcu(&stt->up_next, &sn->stt_up_list);
1671         return 0;
1672 }
1673
1674 static int stt_stop(struct net_device *dev)
1675 {
1676         struct stt_dev *stt_dev = netdev_priv(dev);
1677         struct net *net = stt_dev->net;
1678
1679         list_del_rcu(&stt_dev->up_next);
1680         synchronize_net();
1681         tcp_sock_release(stt_dev->sock);
1682         stt_dev->sock = NULL;
1683         stt_cleanup(net);
1684         return 0;
1685 }
1686
1687 static int __stt_change_mtu(struct net_device *dev, int new_mtu, bool strict)
1688 {
1689         int max_mtu = IP_MAX_MTU - STT_HEADER_LEN - sizeof(struct iphdr)
1690                       - dev->hard_header_len;
1691
1692         if (new_mtu < 68)
1693                 return -EINVAL;
1694
1695         if (new_mtu > max_mtu) {
1696                 if (strict)
1697                         return -EINVAL;
1698
1699                 new_mtu = max_mtu;
1700         }
1701
1702         dev->mtu = new_mtu;
1703         return 0;
1704 }
1705
1706 static int stt_change_mtu(struct net_device *dev, int new_mtu)
1707 {
1708         return __stt_change_mtu(dev, new_mtu, true);
1709 }
1710
1711 static const struct net_device_ops stt_netdev_ops = {
1712         .ndo_init               = stt_init,
1713         .ndo_uninit             = stt_uninit,
1714         .ndo_open               = stt_open,
1715         .ndo_stop               = stt_stop,
1716         .ndo_start_xmit         = stt_dev_xmit,
1717         .ndo_get_stats64        = ip_tunnel_get_stats64,
1718         .ndo_change_mtu         = stt_change_mtu,
1719         .ndo_validate_addr      = eth_validate_addr,
1720         .ndo_set_mac_address    = eth_mac_addr,
1721 };
1722
1723 static void stt_get_drvinfo(struct net_device *dev,
1724                 struct ethtool_drvinfo *drvinfo)
1725 {
1726         strlcpy(drvinfo->version, STT_NETDEV_VER, sizeof(drvinfo->version));
1727         strlcpy(drvinfo->driver, "stt", sizeof(drvinfo->driver));
1728 }
1729
1730 static const struct ethtool_ops stt_ethtool_ops = {
1731         .get_drvinfo    = stt_get_drvinfo,
1732         .get_link       = ethtool_op_get_link,
1733 };
1734
1735 /* Info for udev, that this is a virtual tunnel endpoint */
1736 static struct device_type stt_type = {
1737         .name = "stt",
1738 };
1739
1740 /* Initialize the device structure. */
1741 static void stt_setup(struct net_device *dev)
1742 {
1743         ether_setup(dev);
1744
1745         dev->netdev_ops = &stt_netdev_ops;
1746         dev->ethtool_ops = &stt_ethtool_ops;
1747         dev->destructor = free_netdev;
1748
1749         SET_NETDEV_DEVTYPE(dev, &stt_type);
1750
1751         dev->features    |= NETIF_F_LLTX | NETIF_F_NETNS_LOCAL;
1752         dev->features    |= NETIF_F_SG | NETIF_F_HW_CSUM;
1753         dev->features    |= NETIF_F_RXCSUM;
1754         dev->features    |= NETIF_F_GSO_SOFTWARE;
1755
1756         dev->hw_features |= NETIF_F_SG | NETIF_F_HW_CSUM | NETIF_F_RXCSUM;
1757         dev->hw_features |= NETIF_F_GSO_SOFTWARE;
1758
1759 #ifdef HAVE_METADATA_DST
1760         netif_keep_dst(dev);
1761 #endif
1762         dev->priv_flags |= IFF_LIVE_ADDR_CHANGE | IFF_NO_QUEUE;
1763         eth_hw_addr_random(dev);
1764 }
1765
1766 static const struct nla_policy stt_policy[IFLA_STT_MAX + 1] = {
1767         [IFLA_STT_PORT]              = { .type = NLA_U16 },
1768 };
1769
1770 static int stt_validate(struct nlattr *tb[], struct nlattr *data[])
1771 {
1772         if (tb[IFLA_ADDRESS]) {
1773                 if (nla_len(tb[IFLA_ADDRESS]) != ETH_ALEN)
1774                         return -EINVAL;
1775
1776                 if (!is_valid_ether_addr(nla_data(tb[IFLA_ADDRESS])))
1777                         return -EADDRNOTAVAIL;
1778         }
1779
1780         return 0;
1781 }
1782
1783 static struct stt_dev *find_dev(struct net *net, __be16 dst_port)
1784 {
1785         struct stt_net *sn = net_generic(net, stt_net_id);
1786         struct stt_dev *dev;
1787
1788         list_for_each_entry(dev, &sn->stt_list, next) {
1789                 if (dev->dst_port == dst_port)
1790                         return dev;
1791         }
1792         return NULL;
1793 }
1794
1795 static int stt_configure(struct net *net, struct net_device *dev,
1796                           __be16 dst_port)
1797 {
1798         struct stt_net *sn = net_generic(net, stt_net_id);
1799         struct stt_dev *stt = netdev_priv(dev);
1800         int err;
1801
1802         stt->net = net;
1803         stt->dev = dev;
1804
1805         stt->dst_port = dst_port;
1806
1807         if (find_dev(net, dst_port))
1808                 return -EBUSY;
1809
1810         err = __stt_change_mtu(dev, IP_MAX_MTU, false);
1811         if (err)
1812                 return err;
1813
1814         err = register_netdevice(dev);
1815         if (err)
1816                 return err;
1817
1818         list_add(&stt->next, &sn->stt_list);
1819         return 0;
1820 }
1821
1822 static int stt_newlink(struct net *net, struct net_device *dev,
1823                 struct nlattr *tb[], struct nlattr *data[])
1824 {
1825         __be16 dst_port = htons(STT_DST_PORT);
1826
1827         if (data[IFLA_STT_PORT])
1828                 dst_port = nla_get_be16(data[IFLA_STT_PORT]);
1829
1830         return stt_configure(net, dev, dst_port);
1831 }
1832
1833 static void stt_dellink(struct net_device *dev, struct list_head *head)
1834 {
1835         struct stt_dev *stt = netdev_priv(dev);
1836
1837         list_del(&stt->next);
1838         unregister_netdevice_queue(dev, head);
1839 }
1840
1841 static size_t stt_get_size(const struct net_device *dev)
1842 {
1843         return nla_total_size(sizeof(__be32));  /* IFLA_STT_PORT */
1844 }
1845
1846 static int stt_fill_info(struct sk_buff *skb, const struct net_device *dev)
1847 {
1848         struct stt_dev *stt = netdev_priv(dev);
1849
1850         if (nla_put_be16(skb, IFLA_STT_PORT, stt->dst_port))
1851                 goto nla_put_failure;
1852
1853         return 0;
1854
1855 nla_put_failure:
1856         return -EMSGSIZE;
1857 }
1858
1859 static struct rtnl_link_ops stt_link_ops __read_mostly = {
1860         .kind           = "stt",
1861         .maxtype        = IFLA_STT_MAX,
1862         .policy         = stt_policy,
1863         .priv_size      = sizeof(struct stt_dev),
1864         .setup          = stt_setup,
1865         .validate       = stt_validate,
1866         .newlink        = stt_newlink,
1867         .dellink        = stt_dellink,
1868         .get_size       = stt_get_size,
1869         .fill_info      = stt_fill_info,
1870 };
1871
1872 struct net_device *ovs_stt_dev_create_fb(struct net *net, const char *name,
1873                                       u8 name_assign_type, u16 dst_port)
1874 {
1875         struct nlattr *tb[IFLA_MAX + 1];
1876         struct net_device *dev;
1877         int err;
1878
1879         memset(tb, 0, sizeof(tb));
1880         dev = rtnl_create_link(net, (char *) name, name_assign_type,
1881                         &stt_link_ops, tb);
1882         if (IS_ERR(dev))
1883                 return dev;
1884
1885         err = stt_configure(net, dev, htons(dst_port));
1886         if (err) {
1887                 free_netdev(dev);
1888                 return ERR_PTR(err);
1889         }
1890         return dev;
1891 }
1892 EXPORT_SYMBOL_GPL(ovs_stt_dev_create_fb);
1893
1894 static int stt_init_net(struct net *net)
1895 {
1896         struct stt_net *sn = net_generic(net, stt_net_id);
1897
1898         INIT_LIST_HEAD(&sn->stt_list);
1899         INIT_LIST_HEAD(&sn->stt_up_list);
1900 #ifdef HAVE_NF_REGISTER_NET_HOOK
1901         sn->nf_hook_reg_done = false;
1902 #endif
1903         return 0;
1904 }
1905
1906 static void stt_exit_net(struct net *net)
1907 {
1908         struct stt_net *sn = net_generic(net, stt_net_id);
1909         struct stt_dev *stt, *next;
1910         struct net_device *dev, *aux;
1911         LIST_HEAD(list);
1912
1913 #ifdef HAVE_NF_REGISTER_NET_HOOK
1914         /* Ideally this should be done from stt_stop(), But on some kernels
1915          * nf-unreg operation needs RTNL-lock, which can cause deallock.
1916          * So it is done from here. */
1917         if (sn->nf_hook_reg_done)
1918                 nf_unregister_net_hook(net, &nf_hook_ops);
1919 #endif
1920
1921         rtnl_lock();
1922
1923         /* gather any stt devices that were moved into this ns */
1924         for_each_netdev_safe(net, dev, aux)
1925                 if (dev->rtnl_link_ops == &stt_link_ops)
1926                         unregister_netdevice_queue(dev, &list);
1927
1928         list_for_each_entry_safe(stt, next, &sn->stt_list, next) {
1929                 /* If stt->dev is in the same netns, it was already added
1930                  * to the stt by the previous loop.
1931                  */
1932                 if (!net_eq(dev_net(stt->dev), net))
1933                         unregister_netdevice_queue(stt->dev, &list);
1934         }
1935
1936         /* unregister the devices gathered above */
1937         unregister_netdevice_many(&list);
1938         rtnl_unlock();
1939 }
1940
1941 static struct pernet_operations stt_net_ops = {
1942         .init = stt_init_net,
1943         .exit = stt_exit_net,
1944         .id   = &stt_net_id,
1945         .size = sizeof(struct stt_net),
1946 };
1947
1948 int stt_init_module(void)
1949 {
1950         int rc;
1951
1952         rc = register_pernet_subsys(&stt_net_ops);
1953         if (rc)
1954                 goto out1;
1955
1956         rc = rtnl_link_register(&stt_link_ops);
1957         if (rc)
1958                 goto out2;
1959
1960         INIT_LIST_HEAD(&nf_hook_ops.list);
1961         pr_info("STT tunneling driver\n");
1962         return 0;
1963 out2:
1964         unregister_pernet_subsys(&stt_net_ops);
1965 out1:
1966         return rc;
1967 }
1968
1969 void stt_cleanup_module(void)
1970 {
1971 #ifndef HAVE_NF_REGISTER_NET_HOOK
1972         if (!list_empty(&nf_hook_ops.list))
1973                 nf_unregister_hook(&nf_hook_ops);
1974 #endif
1975         rtnl_link_unregister(&stt_link_ops);
1976         unregister_pernet_subsys(&stt_net_ops);
1977 }
1978 #endif