5d9b035e29e88e9cd0cea4f1f1a1a63f06d319cb
[cascardo/ovs.git] / datapath-windows / ovsext / Checksum.c
1 /*
2  * Copyright (c) 2014 VMware, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "precomp.h"
18 #include "Checksum.h"
19 #include "Flow.h"
20
21 #ifdef OVS_DBG_MOD
22 #undef OVS_DBG_MOD
23 #endif
24 #define OVS_DBG_MOD OVS_DBG_CHECKSUM
25 #include "Debug.h"
26 #include "PacketParser.h"
27
28 #ifndef htons
29 #define htons(_x) (((UINT16)(_x) >> 8) + (((UINT16)(_x) << 8) & 0xff00))
30 #endif
31
32 #ifndef swap64
33 #define swap64(_x) ((((UINT64)(_x) >> 8) & 0x00ff00ff00ff00ff) + \
34                    (((UINT64)(_x) << 8) & 0xff00ff00ff00ff00))
35 #endif
36
37 #define fold64(_x)                             \
38      _x = ((_x) >> 32) + ((_x) & 0xffffffff);  \
39      _x = (UINT32)(((_x) >> 32) + (_x));       \
40      _x = ((_x) >> 16) + ((_x) & 0xffff);      \
41      _x = (UINT16)(((_x) >> 16) + (_x))
42
43 #define fold32(_x)                             \
44      _x = ((_x) >> 16) + ((_x) & 0xffff);      \
45      _x = (UINT16)(((_x) >> 16) + (_x))
46
47
48 /*
49  *----------------------------------------------------------------------------
50  * CalculateOnesComplement --
51  *
52  *  Given the start address and buffer length, calculate the 1's complement
53  *  This routine can be used when multiple buffers are used for a packets.
54  *
55  *  PLEASE NOTE, even though the last parameter is UINT64, but the assumption
56  *  is it will not overflowed after adding the extra data.
57  *     ------------------------------------------------
58  *
59  * Result:
60  *    As name indicate, the final data is not 1's complemnent
61  *----------------------------------------------------------------------------
62  */
63 UINT64
64 CalculateOnesComplement(UINT8 *start,
65                         UINT16 totalLength,
66                         UINT64 initial,
67                         BOOLEAN isEvenStart)
68 {
69     UINT64  sum = 0, val;
70     UINT64  *src = (UINT64 *)start;
71     while (totalLength > 7) {
72         val = *src;
73         sum += val;
74         if (sum < val) sum++;
75         src++;
76         totalLength -= 8;
77     }
78
79     start = (UINT8 *)src;
80
81     if (totalLength > 3) {
82         UINT32 val = *(UINT32 *)start;
83         sum += val;
84         if (sum < val) sum++;
85         start += 4;
86         totalLength -= 4;
87     }
88
89     if (totalLength > 1) {
90         UINT16 val = *(UINT16 *)start;
91         sum += val;
92         if (sum < val) sum++;
93         start += 2;
94         totalLength -= 2;
95     }
96
97     if (totalLength > 0) {
98         UINT8 val = *start;
99         sum += val;
100         if (sum < val) sum++;
101         start += 1;
102         totalLength -= 1;
103     }
104     ASSERT(totalLength == 0);
105
106     if (!isEvenStart) {
107         sum = _byteswap_uint64(sum);
108     }
109
110     sum += initial;
111     if (sum < initial) sum++;
112
113     return sum;
114 }
115
116 /*
117  *----------------------------------------------------------------------------
118  * CalculateChecksum --
119  *
120  *   Given the start point, and length, calculate the checksum
121  *   as 1's complement of 1's comlement.
122  *
123  *   This assume the checksum field is initailized properly.
124  *
125  * Input Parameter:
126  *    ptr:  point to the data to be checksumed
127  *    totalLength: total length of the data
128  *    initial: inital value to remit the checksum. Please note this
129  *             value should be network byte order value.
130  *
131  *    The last parameter may be useful where you don't want to set
132  *    checksum field to zero, in that case you can pass ~checksum,
133  *    this is equivalent of set checksum field to zero.
134  *
135  * Result:
136  *    The result can be assigned to checksum field directly.
137  *----------------------------------------------------------------------------
138  */
139 UINT16
140 CalculateChecksum(UINT8 *ptr,
141                   UINT16 totalLength,
142                   UINT16 initial)
143 {
144     UINT64  sum = CalculateOnesComplement(ptr, totalLength, initial, TRUE);
145     fold64(sum);
146     return (UINT16)~sum;
147 }
148
149 /*
150  *----------------------------------------------------------------------------
151  * CopyAndCalculateOnesComplement --
152  *
153  *  Given the start address and buffer length, calculate the 1's complement
154  *  at same time, copt the data from src to dst.
155  *
156  *  This routine can be used when multiple buffers are used for a packets.
157  *
158  *  PLEASE NOTE, even though the last parameter is UINT64, but the assumption
159  *  is it will not overflowed after adding the extra data.
160  *     ------------------------------------------------
161  *
162  * Result:
163  *    As name indicate, the final data is not 1's complemnent
164  *----------------------------------------------------------------------------
165  */
166 UINT64
167 CopyAndCalculateOnesComplement(UINT8 *dst,
168                                UINT8 *src,
169                                UINT16 length,
170                                UINT64 initial,
171                                BOOLEAN isEvenStart)
172 {
173     UINT64  sum =0, val;
174     UINT64 *src64, *dst64;
175     union {
176         UINT32 val;
177         UINT8  b8[4];
178     } tmp;
179
180     src64 = (UINT64 *)src;
181     dst64 = (UINT64 *)dst;
182
183     while (length > 7) {
184         val = *src64;
185         *dst64 = val;
186         sum += (val >> 32) + (val & 0xffffffff);
187         src64++;
188         dst64++;
189         length -= 8;
190     }
191
192     if (length > 3) {
193         val = *(UINT32 *)src64;
194         *(UINT32 *)dst64 = (UINT32)val;
195         sum += (UINT32)val;
196         dst64 = (UINT64 *)((UINT8 *)dst64 + 4);
197         src64 = (UINT64 *)((UINT8 *)src64 + 4);
198         length -= 4;
199     }
200     src = (UINT8 *)src64;
201     dst = (UINT8 *)dst64;
202     tmp.val = 0;
203     switch (length) {
204     case 3:
205         dst[2] = src[2];
206         tmp.b8[2] = src[2];
207     case 2:
208         dst[1] = src[1];
209         tmp.b8[1] = src[1];
210     case 1:
211         dst[0] = src[0];
212         tmp.b8[0] = src[0];
213         sum += tmp.val;
214     }
215     sum = (isEvenStart ? sum : swap64(sum)) + initial;
216     return sum;
217 }
218
219 /*
220  *----------------------------------------------------------------------------
221  * CopyAndCalculateChecksum --
222  *
223  *  This is similar to CalculateChecksum, except it will also copy data to
224  *  destination address.
225  *----------------------------------------------------------------------------
226  */
227 UINT16
228 CopyAndCalculateChecksum(UINT8 *dst,
229                          UINT8 *src,
230                          UINT16 length,
231                          UINT16 initial)
232 {
233
234     UINT64  sum = CopyAndCalculateOnesComplement(dst, src, length, initial,
235                                                  TRUE);
236     fold64(sum);
237     return (UINT16)~sum;
238 }
239
240
241 /*
242  *----------------------------------------------------------------------------
243  * IPChecksum --
244  *
245  *   Give IP header, calculate the IP checksum.
246  *   We assume IP checksum field is initialized properly
247  *
248  *  Input Pramater:
249  *   ipHdr: IP header start point
250  *   length: IP header length (potentially include IP options)
251  *   initial: same as CalculateChecksum
252  *
253  *  Result:
254  *   The result is already 1's complement, so can be assigned
255  *   to checksum field directly
256  *----------------------------------------------------------------------------
257  */
258 UINT16
259 IPChecksum(UINT8 *ipHdr,
260            UINT16 length,
261            UINT16 initial)
262 {
263     UINT32 sum = initial;
264     UINT16 *ptr = (UINT16 *)ipHdr;
265     ASSERT((length & 0x3) == 0);
266     while (length > 1) {
267         sum += ptr[0];
268         ptr++;
269         length -= 2;
270     }
271     fold32(sum);
272     return (UINT16)~sum;
273 }
274
275 /*
276  *----------------------------------------------------------------------------
277  *  IPPseudoChecksum --
278  *
279  *   Give src and dst IP address, protocol value and total
280  *   upper layer length(not include IP header, but include
281  *   upller layer protocol header, for example it include
282  *   TCP header for TCP checksum), calculate the pseudo
283  *   checksum, please note this checksum is just 1's complement
284  *   addition.
285  *
286  *  Input Parameter:
287  *    src: please note it is in network byte order
288  *    dst: same as src
289  *    protocol: protocol value in IP header
290  *    totalLength: total length of upper layer data including
291  *          header.
292  *
293  *  Result:
294  *
295  *   This value should be put in TCP checksum field before
296  *   calculating TCP checksum using CalculateChecksum with
297  *   initial value of 0.
298  *----------------------------------------------------------------------------
299  */
300 UINT16
301 IPPseudoChecksum(UINT32 *src,
302                  UINT32 *dst,
303                  UINT8 protocol,
304                  UINT16 totalLength)
305 {
306     UINT32 sum = (UINT32)htons(totalLength) + htons(protocol);
307     sum += (*src >> 16) + (*src & 0xffff);
308     sum += (*dst >> 16) + (*dst & 0xffff);
309     fold32(sum);
310     return (UINT16)sum;
311 }
312
313 /*
314  *----------------------------------------------------------------------------
315  * IPv6PseudoChecksum --
316  *
317  *  Given IPv6 src and dst address, upper layer protocol and total
318  *  upper layer protocol data length including upper layer header
319  *  part, calculate the pseudo checksum for upper layer protocol
320  *  checksum.
321  *
322  *  please note this checksum is just 1's complement addition.
323  *
324  *  Input Parameter:
325  *    src:   src IPv6 address in network byte order
326  *    dst:   dst IPv6 address.
327  *    protocol: upper layer protocol
328  *    totalLength: total length of upper layer data. Please note this is
329  *         in host byte order.
330  *
331  *  Result:
332  *
333  *  Place in upper layer checksum field before calculate upper layer
334  *  checksum.
335  *----------------------------------------------------------------------------
336  */
337 UINT16
338 IPv6PseudoChecksum(UINT32 *src,
339                    UINT32 *dst,
340                    UINT8 protocol,
341                    UINT16 totalLength)
342 {
343     UINT64 sum = (UINT32)htons(totalLength) + htons(protocol);
344     sum += (UINT64)src[0] + src[1] + src[2] + src[3];
345     sum += (UINT64)dst[0] + dst[1] + dst[2] + dst[3];
346     fold64(sum);
347     return (UINT16)sum;
348 }
349
350 /*
351  *----------------------------------------------------------------------------
352  * ChecksumUpdate32 --
353  *
354  *  Given old checksum value (as it is in checksum field),
355  *  prev value of the relevant field in network byte order
356  *  new value of the relevant field in the network byte order
357  *  calculate the new checksum.
358  *  Please check relevant RFC for reference.
359  *
360  *  Input Pramater:
361  *     oldSum: old checksum value in checksum field
362  *     prev:   previous value of relevant 32 bit feld in network
363  *             byte order.
364  *     new:    new value of the relevant 32 bit field in network
365  *             byte order.
366  *
367  *  Result:
368  *     new checksum value to be placed in the checksum field.
369  *----------------------------------------------------------------------------
370  */
371 UINT16
372 ChecksumUpdate32(UINT16 oldSum,
373                  UINT32 prev,
374                  UINT32 newValue)
375 {
376     UINT32 sum = ~prev;
377     sum = (sum >> 16) + (sum & 0xffff);
378     sum += (newValue >> 16) + (newValue & 0xffff);
379     sum += (UINT16)~oldSum;
380     fold32(sum);
381     return (UINT16)~sum;
382 }
383
384
385 /*
386  *----------------------------------------------------------------------------
387  * ChecksumUpdate16 --
388  *
389  *  Given old checksum value (as it is in checksum field),
390  *  prev value of the relevant field in network byte order
391  *  new value of the relevant field in the network byte order
392  *  calculate the new checksum.
393  *  Please check relevant RFC for reference.
394  *
395  *  Input Pramater:
396  *     oldSum: old checksum value in checksum field
397  *     prev:   previous value of relevant 32 bit feld in network
398  *             byte order.
399  *     new:    new value of the relevant 32 bit field in network
400  *             byte order.
401  *
402  *  Result:
403  *     new checksum value to be placed in the checksum field.
404  *----------------------------------------------------------------------------
405  */
406 UINT16
407 ChecksumUpdate16(UINT16 oldSum,
408                  UINT16 prev,
409                  UINT16 newValue)
410 {
411     UINT32 sum = (UINT16)~oldSum;
412     sum += (UINT32)((UINT16)~prev) + newValue;
413     fold32(sum);
414     return (UINT16)~sum;
415 }
416
417 /*
418  *----------------------------------------------------------------------------
419  * CalculateChecksumNB --
420  *
421  * Calculates checksum over a length of bytes contained in an NB.
422  *
423  * nb           : NB which contains the packet bytes.
424  * csumDataLen  : Length of bytes to be checksummed.
425  * offset       : offset to the first bytes of the data stream to be
426  *                checksumed.
427  *
428  * Result:
429  *  return 0, if there is a failure.
430  *----------------------------------------------------------------------------
431  */
432 UINT16
433 CalculateChecksumNB(const PNET_BUFFER nb,
434                     UINT16 csumDataLen,
435                     UINT32 offset)
436 {
437     ULONG mdlLen;
438     UINT16 csLen;
439     PUCHAR src;
440     UINT64 csum = 0;
441     PMDL currentMdl;
442     ULONG firstMdlLen;
443     /* Running count of bytes in remainder of the MDLs including current. */
444     ULONG packetLen;
445     BOOLEAN swapEnd = 1 & csumDataLen;
446
447     if ((nb == NULL) || (csumDataLen == 0)
448             || (offset >= NET_BUFFER_DATA_LENGTH(nb))
449             || (offset + csumDataLen > NET_BUFFER_DATA_LENGTH(nb))) {
450         OVS_LOG_ERROR("Invalid parameters - csum length %u, offset %u,"
451                 "pkt%s len %u", csumDataLen, offset, nb? "":"(null)",
452                 nb? NET_BUFFER_DATA_LENGTH(nb) : 0);
453         return 0;
454     }
455
456     currentMdl = NET_BUFFER_CURRENT_MDL(nb);
457     packetLen = NET_BUFFER_DATA_LENGTH(nb);
458     firstMdlLen =
459         MmGetMdlByteCount(currentMdl) - NET_BUFFER_CURRENT_MDL_OFFSET(nb);
460
461     firstMdlLen = MIN(firstMdlLen, packetLen);
462     if (offset < firstMdlLen) {
463         src = (PUCHAR) MmGetSystemAddressForMdlSafe(currentMdl, LowPagePriority);
464         if (!src) {
465             return 0;
466         }
467         src += (NET_BUFFER_CURRENT_MDL_OFFSET(nb) + offset);
468         mdlLen = firstMdlLen - offset;
469         packetLen -= firstMdlLen;
470         ASSERT((INT)packetLen >= 0);
471     } else {
472         offset -= firstMdlLen;
473         packetLen -= firstMdlLen;
474         ASSERT((INT)packetLen >= 0);
475         currentMdl = NDIS_MDL_LINKAGE(currentMdl);
476         mdlLen = MmGetMdlByteCount(currentMdl);
477         mdlLen = MIN(mdlLen, packetLen);
478
479         while (offset >= mdlLen) {
480             offset -= mdlLen;
481             packetLen -= mdlLen;
482             ASSERT((INT)packetLen >= 0);
483             currentMdl = NDIS_MDL_LINKAGE(currentMdl);
484             mdlLen = MmGetMdlByteCount(currentMdl);
485             mdlLen = MIN(mdlLen, packetLen);
486         }
487
488         src = (PUCHAR)MmGetSystemAddressForMdlSafe(currentMdl, LowPagePriority);
489         if (!src) {
490             return 0;
491         }
492
493         src += offset;
494         mdlLen -= offset;
495     }
496
497     while (csumDataLen && (currentMdl != NULL)) {
498         ASSERT(mdlLen < 65536);
499         csLen = MIN((UINT16) mdlLen, csumDataLen);
500
501         csum = CalculateOnesComplement(src, csLen, csum, !(1 & csumDataLen));
502         fold64(csum);
503
504         csumDataLen -= csLen;
505         currentMdl = NDIS_MDL_LINKAGE(currentMdl);
506         if (csumDataLen && currentMdl) {
507             src = MmGetSystemAddressForMdlSafe(currentMdl, LowPagePriority);
508             if (!src) {
509                 return 0;
510             }
511
512             mdlLen = MmGetMdlByteCount(currentMdl);
513             mdlLen = MIN(mdlLen, packetLen);
514             /* packetLen does not include the current MDL from here on. */
515             packetLen -= mdlLen;
516             ASSERT((INT)packetLen >= 0);
517         }
518     }
519
520     fold64(csum);
521     ASSERT(csumDataLen == 0);
522     ASSERT((csum & ~0xffff) == 0);
523     csum = (UINT16)~csum;
524     if (swapEnd) {
525         return _byteswap_ushort((UINT16)csum);
526     }
527     return (UINT16)csum;
528 }
529
530 /*
531  * --------------------------------------------------------------------------
532  * OvsValidateIPChecksum
533  * --------------------------------------------------------------------------
534  */
535 NDIS_STATUS
536 OvsValidateIPChecksum(PNET_BUFFER_LIST curNbl,
537                       POVS_PACKET_HDR_INFO hdrInfo)
538 {
539     NDIS_TCP_IP_CHECKSUM_NET_BUFFER_LIST_INFO csumInfo;
540     uint16_t checksum, hdrChecksum;
541     struct IPHdr ip_storage;
542     const IPHdr *ipHdr;
543
544     if (!hdrInfo->isIPv4) {
545         return NDIS_STATUS_SUCCESS;
546     }
547
548     /* First check if NIC has indicated checksum failure. */
549     csumInfo.Value = NET_BUFFER_LIST_INFO(curNbl,
550                                           TcpIpChecksumNetBufferListInfo);
551     if (csumInfo.Receive.IpChecksumFailed) {
552         return NDIS_STATUS_FAILURE;
553     }
554
555     /* Next, check if the NIC did not validate the RX checksum. */
556     if (!csumInfo.Receive.IpChecksumSucceeded) {
557         ipHdr = OvsGetIp(curNbl, hdrInfo->l3Offset, &ip_storage);
558         if (ipHdr) {
559             ip_storage = *ipHdr;
560             hdrChecksum = ipHdr->check;
561             ip_storage.check = 0;
562             checksum = IPChecksum((uint8 *)&ip_storage, ipHdr->ihl * 4, 0);
563             if (checksum != hdrChecksum) {
564                 return NDIS_STATUS_FAILURE;
565             }
566         }
567     }
568     return NDIS_STATUS_SUCCESS;
569 }
570
571 /*
572  *----------------------------------------------------------------------------
573  * OvsValidateUDPChecksum
574  *----------------------------------------------------------------------------
575  */
576 NDIS_STATUS
577 OvsValidateUDPChecksum(PNET_BUFFER_LIST curNbl, BOOLEAN udpCsumZero)
578 {
579     NDIS_TCP_IP_CHECKSUM_NET_BUFFER_LIST_INFO csumInfo;
580
581     csumInfo.Value = NET_BUFFER_LIST_INFO(curNbl, TcpIpChecksumNetBufferListInfo);
582
583     if (udpCsumZero) {
584         /* Zero is valid checksum. */
585         csumInfo.Receive.UdpChecksumFailed = 0;
586         NET_BUFFER_LIST_INFO(curNbl, TcpIpChecksumNetBufferListInfo) = csumInfo.Value;
587         return NDIS_STATUS_SUCCESS;
588     }
589
590     /* First check if NIC has indicated UDP checksum failure. */
591     if (csumInfo.Receive.UdpChecksumFailed) {
592         return NDIS_STATUS_INVALID_PACKET;
593     }
594
595     return NDIS_STATUS_SUCCESS;
596 }