datapath-windows: Removed memory barrier and master lock
[cascardo/ovs.git] / datapath-windows / ovsext / Checksum.c
1 /*
2  * Copyright (c) 2014 VMware, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "precomp.h"
18 #include "Checksum.h"
19 #include "Flow.h"
20
21 #ifdef OVS_DBG_MOD
22 #undef OVS_DBG_MOD
23 #endif
24 #define OVS_DBG_MOD OVS_DBG_CHECKSUM
25 #include "Debug.h"
26 #include "PacketParser.h"
27
28 #ifndef htons
29 #define htons(_x) (((UINT16)(_x) >> 8) + (((UINT16)(_x) << 8) & 0xff00))
30 #endif
31
32 #ifndef swap64
33 #define swap64(_x) ((((UINT64)(_x) >> 8) & 0x00ff00ff00ff00ff) + \
34                    (((UINT64)(_x) << 8) & 0xff00ff00ff00ff00))
35 #endif
36
37 #define fold64(_x)                             \
38      _x = ((_x) >> 32) + ((_x) & 0xffffffff);  \
39      _x = (UINT32)(((_x) >> 32) + (_x));       \
40      _x = ((_x) >> 16) + ((_x) & 0xffff);      \
41      _x = (UINT16)(((_x) >> 16) + (_x))
42
43 #define fold32(_x)                             \
44      _x = ((_x) >> 16) + ((_x) & 0xffff);      \
45      _x = (UINT16)(((_x) >> 16) + (_x))
46
47
48 /*
49  *----------------------------------------------------------------------------
50  * CalculateOnesComplement --
51  *
52  *  Given the start address and buffer length, calculate the 1's complement
53  *  This routine can be used when multiple buffers are used for a packets.
54  *
55  *  PLEASE NOTE, even though the last parameter is UINT64, but the assumption
56  *  is it will not overflowed after adding the extra data.
57  *     ------------------------------------------------
58  *
59  * Result:
60  *    As name indicate, the final data is not 1's complemnent
61  *----------------------------------------------------------------------------
62  */
63 UINT64
64 CalculateOnesComplement(UINT8 *start,
65                         UINT16 totalLength,
66                         UINT64 initial,
67                         BOOLEAN isEvenStart)
68 {
69     UINT64  sum = 0, val;
70     UINT64  *src = (UINT64 *)start;
71     union {
72         UINT32 val;
73         UINT8  b8[4];
74     } tmp;
75
76     while (totalLength > 7) {
77         val = *src;
78         sum += (val >> 32) + (val & 0xffffffff);
79         src++;
80         totalLength -= 8;
81     }
82     if (totalLength > 3) {
83         sum += *(UINT32 *)src;
84         src = (UINT64 *)((UINT8 *)src + 4);
85         totalLength -= 4;
86     }
87     start = (UINT8 *)src;
88     tmp.val = 0;
89     switch (totalLength) {
90     case 3:
91         tmp.b8[2] = start[2];
92     case 2:
93         tmp.b8[1] = start[1];
94     case 1:
95         tmp.b8[0] = start[0];
96         sum += tmp.val;
97     }
98     sum = (isEvenStart ? sum : swap64(sum)) + initial;
99     return sum;
100 }
101
102 /*
103  *----------------------------------------------------------------------------
104  * CalculateChecksum --
105  *
106  *   Given the start point, and length, calculate the checksum
107  *   as 1's complement of 1's comlement.
108  *
109  *   This assume the checksum field is initailized properly.
110  *
111  * Input Parameter:
112  *    ptr:  point to the data to be checksumed
113  *    totalLength: total length of the data
114  *    initial: inital value to remit the checksum. Please note this
115  *             value should be network byte order value.
116  *
117  *    The last parameter may be useful where you don't want to set
118  *    checksum field to zero, in that case you can pass ~checksum,
119  *    this is equivalent of set checksum field to zero.
120  *
121  * Result:
122  *    The result can be assigned to checksum field directly.
123  *----------------------------------------------------------------------------
124  */
125 UINT16
126 CalculateChecksum(UINT8 *ptr,
127                   UINT16 totalLength,
128                   UINT16 initial)
129 {
130     UINT64  sum = CalculateOnesComplement(ptr, totalLength, initial, TRUE);
131     fold64(sum);
132     return (UINT16)~sum;
133 }
134
135 /*
136  *----------------------------------------------------------------------------
137  * CopyAndCalculateOnesComplement --
138  *
139  *  Given the start address and buffer length, calculate the 1's complement
140  *  at same time, copt the data from src to dst.
141  *
142  *  This routine can be used when multiple buffers are used for a packets.
143  *
144  *  PLEASE NOTE, even though the last parameter is UINT64, but the assumption
145  *  is it will not overflowed after adding the extra data.
146  *     ------------------------------------------------
147  *
148  * Result:
149  *    As name indicate, the final data is not 1's complemnent
150  *----------------------------------------------------------------------------
151  */
152 UINT64
153 CopyAndCalculateOnesComplement(UINT8 *dst,
154                                UINT8 *src,
155                                UINT16 length,
156                                UINT64 initial,
157                                BOOLEAN isEvenStart)
158 {
159     UINT64  sum =0, val;
160     UINT64 *src64, *dst64;
161     union {
162         UINT32 val;
163         UINT8  b8[4];
164     } tmp;
165
166     src64 = (UINT64 *)src;
167     dst64 = (UINT64 *)dst;
168
169     while (length > 7) {
170         val = *src64;
171         *dst64 = val;
172         sum += (val >> 32) + (val & 0xffffffff);
173         src64++;
174         dst64++;
175         length -= 8;
176     }
177
178     if (length > 3) {
179         val = *(UINT32 *)src64;
180         *(UINT32 *)dst64 = (UINT32)val;
181         sum += (UINT32)val;
182         dst64 = (UINT64 *)((UINT8 *)dst64 + 4);
183         src64 = (UINT64 *)((UINT8 *)src64 + 4);
184         length -= 4;
185     }
186     src = (UINT8 *)src64;
187     dst = (UINT8 *)dst64;
188     tmp.val = 0;
189     switch (length) {
190     case 3:
191         dst[2] = src[2];
192         tmp.b8[2] = src[2];
193     case 2:
194         dst[1] = src[1];
195         tmp.b8[1] = src[1];
196     case 1:
197         dst[0] = src[0];
198         tmp.b8[0] = src[0];
199         sum += tmp.val;
200     }
201     sum = (isEvenStart ? sum : swap64(sum)) + initial;
202     return sum;
203 }
204
205 /*
206  *----------------------------------------------------------------------------
207  * CopyAndCalculateChecksum --
208  *
209  *  This is similar to CalculateChecksum, except it will also copy data to
210  *  destination address.
211  *----------------------------------------------------------------------------
212  */
213 UINT16
214 CopyAndCalculateChecksum(UINT8 *dst,
215                          UINT8 *src,
216                          UINT16 length,
217                          UINT16 initial)
218 {
219
220     UINT64  sum = CopyAndCalculateOnesComplement(dst, src, length, initial,
221                                                  TRUE);
222     fold64(sum);
223     return (UINT16)~sum;
224 }
225
226
227 /*
228  *----------------------------------------------------------------------------
229  * IPChecksum --
230  *
231  *   Give IP header, calculate the IP checksum.
232  *   We assume IP checksum field is initialized properly
233  *
234  *  Input Pramater:
235  *   ipHdr: IP header start point
236  *   length: IP header length (potentially include IP options)
237  *   initial: same as CalculateChecksum
238  *
239  *  Result:
240  *   The result is already 1's complement, so can be assigned
241  *   to checksum field directly
242  *----------------------------------------------------------------------------
243  */
244 UINT16
245 IPChecksum(UINT8 *ipHdr,
246            UINT16 length,
247            UINT16 initial)
248 {
249     UINT32 sum = initial;
250     UINT16 *ptr = (UINT16 *)ipHdr;
251     ASSERT((length & 0x3) == 0);
252     while (length > 1) {
253         sum += ptr[0];
254         ptr++;
255         length -= 2;
256     }
257     fold32(sum);
258     return (UINT16)~sum;
259 }
260
261 /*
262  *----------------------------------------------------------------------------
263  *  IPPseudoChecksum --
264  *
265  *   Give src and dst IP address, protocol value and total
266  *   upper layer length(not include IP header, but include
267  *   upller layer protocol header, for example it include
268  *   TCP header for TCP checksum), calculate the pseudo
269  *   checksum, please note this checksum is just 1's complement
270  *   addition.
271  *
272  *  Input Parameter:
273  *    src: please note it is in network byte order
274  *    dst: same as src
275  *    protocol: protocol value in IP header
276  *    totalLength: total length of upper layer data including
277  *          header.
278  *
279  *  Result:
280  *
281  *   This value should be put in TCP checksum field before
282  *   calculating TCP checksum using CalculateChecksum with
283  *   initial value of 0.
284  *----------------------------------------------------------------------------
285  */
286 UINT16
287 IPPseudoChecksum(UINT32 *src,
288                  UINT32 *dst,
289                  UINT8 protocol,
290                  UINT16 totalLength)
291 {
292     UINT32 sum = (UINT32)htons(totalLength) + htons(protocol);
293     sum += (*src >> 16) + (*src & 0xffff);
294     sum += (*dst >> 16) + (*dst & 0xffff);
295     fold32(sum);
296     return (UINT16)sum;
297 }
298
299 /*
300  *----------------------------------------------------------------------------
301  * IPv6PseudoChecksum --
302  *
303  *  Given IPv6 src and dst address, upper layer protocol and total
304  *  upper layer protocol data length including upper layer header
305  *  part, calculate the pseudo checksum for upper layer protocol
306  *  checksum.
307  *
308  *  please note this checksum is just 1's complement addition.
309  *
310  *  Input Parameter:
311  *    src:   src IPv6 address in network byte order
312  *    dst:   dst IPv6 address.
313  *    protocol: upper layer protocol
314  *    totalLength: total length of upper layer data. Please note this is
315  *         in host byte order.
316  *
317  *  Result:
318  *
319  *  Place in upper layer checksum field before calculate upper layer
320  *  checksum.
321  *----------------------------------------------------------------------------
322  */
323 UINT16
324 IPv6PseudoChecksum(UINT32 *src,
325                    UINT32 *dst,
326                    UINT8 protocol,
327                    UINT16 totalLength)
328 {
329     UINT64 sum = (UINT32)htons(totalLength) + htons(protocol);
330     sum += (UINT64)src[0] + src[1] + src[2] + src[3];
331     sum += (UINT64)dst[0] + dst[1] + dst[2] + dst[3];
332     fold64(sum);
333     return (UINT16)sum;
334 }
335
336 /*
337  *----------------------------------------------------------------------------
338  * ChecksumUpdate32 --
339  *
340  *  Given old checksum value (as it is in checksum field),
341  *  prev value of the relevant field in network byte order
342  *  new value of the relevant field in the network byte order
343  *  calculate the new checksum.
344  *  Please check relevant RFC for reference.
345  *
346  *  Input Pramater:
347  *     oldSum: old checksum value in checksum field
348  *     prev:   previous value of relevant 32 bit feld in network
349  *             byte order.
350  *     new:    new value of the relevant 32 bit field in network
351  *             byte order.
352  *
353  *  Result:
354  *     new checksum value to be placed in the checksum field.
355  *----------------------------------------------------------------------------
356  */
357 UINT16
358 ChecksumUpdate32(UINT16 oldSum,
359                  UINT32 prev,
360                  UINT32 newValue)
361 {
362     UINT32 sum = ~prev;
363     sum = (sum >> 16) + (sum & 0xffff);
364     sum += (newValue >> 16) + (newValue & 0xffff);
365     sum += (UINT16)~oldSum;
366     fold32(sum);
367     return (UINT16)~sum;
368 }
369
370
371 /*
372  *----------------------------------------------------------------------------
373  * ChecksumUpdate16 --
374  *
375  *  Given old checksum value (as it is in checksum field),
376  *  prev value of the relevant field in network byte order
377  *  new value of the relevant field in the network byte order
378  *  calculate the new checksum.
379  *  Please check relevant RFC for reference.
380  *
381  *  Input Pramater:
382  *     oldSum: old checksum value in checksum field
383  *     prev:   previous value of relevant 32 bit feld in network
384  *             byte order.
385  *     new:    new value of the relevant 32 bit field in network
386  *             byte order.
387  *
388  *  Result:
389  *     new checksum value to be placed in the checksum field.
390  *----------------------------------------------------------------------------
391  */
392 UINT16
393 ChecksumUpdate16(UINT16 oldSum,
394                  UINT16 prev,
395                  UINT16 newValue)
396 {
397     UINT32 sum = (UINT16)~oldSum;
398     sum += (UINT32)((UINT16)~prev) + newValue;
399     fold32(sum);
400     return (UINT16)~sum;
401 }
402
403 /*
404  *----------------------------------------------------------------------------
405  * CalculateChecksumNB --
406  *
407  * Calculates checksum over a length of bytes contained in an NB.
408  *
409  * nb           : NB which contains the packet bytes.
410  * csumDataLen  : Length of bytes to be checksummed.
411  * offset       : offset to the first bytes of the data stream to be
412  *                checksumed.
413  *
414  * Result:
415  *  return 0, if there is a failure.
416  *----------------------------------------------------------------------------
417  */
418 UINT16
419 CalculateChecksumNB(const PNET_BUFFER nb,
420                     UINT16 csumDataLen,
421                     UINT32 offset)
422 {
423     ULONG mdlLen;
424     UINT16 csLen;
425     PUCHAR src;
426     UINT64 csum = 0;
427     PMDL currentMdl;
428     ULONG firstMdlLen;
429     /* Running count of bytes in remainder of the MDLs including current. */
430     ULONG packetLen;
431
432     if ((nb == NULL) || (csumDataLen == 0)
433             || (offset >= NET_BUFFER_DATA_LENGTH(nb))
434             || (offset + csumDataLen > NET_BUFFER_DATA_LENGTH(nb))) {
435         OVS_LOG_ERROR("Invalid parameters - csum length %u, offset %u,"
436                 "pkt%s len %u", csumDataLen, offset, nb? "":"(null)",
437                 nb? NET_BUFFER_DATA_LENGTH(nb) : 0);
438         return 0;
439     }
440
441     currentMdl = NET_BUFFER_CURRENT_MDL(nb);
442     packetLen = NET_BUFFER_DATA_LENGTH(nb);
443     firstMdlLen =
444         MmGetMdlByteCount(currentMdl) - NET_BUFFER_CURRENT_MDL_OFFSET(nb);
445
446     firstMdlLen = MIN(firstMdlLen, packetLen);
447     if (offset < firstMdlLen) {
448         src = (PUCHAR) MmGetSystemAddressForMdlSafe(currentMdl, LowPagePriority);
449         if (!src) {
450             return 0;
451         }
452         src += (NET_BUFFER_CURRENT_MDL_OFFSET(nb) + offset);
453         mdlLen = firstMdlLen - offset;
454         packetLen -= firstMdlLen;
455         ASSERT((INT)packetLen >= 0);
456     } else {
457         offset -= firstMdlLen;
458         packetLen -= firstMdlLen;
459         ASSERT((INT)packetLen >= 0);
460         currentMdl = NDIS_MDL_LINKAGE(currentMdl);
461         mdlLen = MmGetMdlByteCount(currentMdl);
462         mdlLen = MIN(mdlLen, packetLen);
463
464         while (offset >= mdlLen) {
465             offset -= mdlLen;
466             packetLen -= mdlLen;
467             ASSERT((INT)packetLen >= 0);
468             currentMdl = NDIS_MDL_LINKAGE(currentMdl);
469             mdlLen = MmGetMdlByteCount(currentMdl);
470             mdlLen = MIN(mdlLen, packetLen);
471         }
472
473         src = (PUCHAR)MmGetSystemAddressForMdlSafe(currentMdl, LowPagePriority);
474         if (!src) {
475             return 0;
476         }
477
478         src += offset;
479         mdlLen -= offset;
480     }
481
482     while (csumDataLen && (currentMdl != NULL)) {
483         ASSERT(mdlLen < 65536);
484         csLen = MIN((UINT16) mdlLen, csumDataLen);
485         //XXX Not handling odd bytes yet.
486         ASSERT(((csLen & 0x1) == 0) || csumDataLen <= mdlLen);
487
488         csum = CalculateOnesComplement(src, csLen, csum, TRUE);
489         fold64(csum);
490
491         csumDataLen -= csLen;
492         currentMdl = NDIS_MDL_LINKAGE(currentMdl);
493         if (csumDataLen && currentMdl) {
494             src = MmGetSystemAddressForMdlSafe(currentMdl, LowPagePriority);
495             if (!src) {
496                 return 0;
497             }
498
499             mdlLen = MmGetMdlByteCount(currentMdl);
500             mdlLen = MIN(mdlLen, packetLen);
501             /* packetLen does not include the current MDL from here on. */
502             packetLen -= mdlLen;
503             ASSERT((INT)packetLen >= 0);
504         }
505     }
506
507     ASSERT(csumDataLen == 0);
508     ASSERT((csum & ~0xffff) == 0);
509     return (UINT16) ~csum;
510 }
511
512 /*
513  * --------------------------------------------------------------------------
514  * OvsValidateIPChecksum
515  * --------------------------------------------------------------------------
516  */
517 NDIS_STATUS
518 OvsValidateIPChecksum(PNET_BUFFER_LIST curNbl,
519                       POVS_PACKET_HDR_INFO hdrInfo)
520 {
521     NDIS_TCP_IP_CHECKSUM_NET_BUFFER_LIST_INFO csumInfo;
522     uint16_t checksum, hdrChecksum;
523     struct IPHdr ip_storage;
524     const IPHdr *ipHdr;
525
526     if (!hdrInfo->isIPv4) {
527         return NDIS_STATUS_SUCCESS;
528     }
529
530     /* First check if NIC has indicated checksum failure. */
531     csumInfo.Value = NET_BUFFER_LIST_INFO(curNbl,
532                                           TcpIpChecksumNetBufferListInfo);
533     if (csumInfo.Receive.IpChecksumFailed) {
534         return NDIS_STATUS_FAILURE;
535     }
536
537     /* Next, check if the NIC did not validate the RX checksum. */
538     if (!csumInfo.Receive.IpChecksumSucceeded) {
539         ipHdr = OvsGetIp(curNbl, hdrInfo->l3Offset, &ip_storage);
540         if (ipHdr) {
541             ip_storage = *ipHdr;
542             hdrChecksum = ipHdr->check;
543             ip_storage.check = 0;
544             checksum = IPChecksum((uint8 *)&ip_storage, ipHdr->ihl * 4, 0);
545             if (checksum != hdrChecksum) {
546                 return NDIS_STATUS_FAILURE;
547             }
548         }
549     }
550     return NDIS_STATUS_SUCCESS;
551 }
552
553 /*
554  *----------------------------------------------------------------------------
555  * OvsValidateUDPChecksum
556  *----------------------------------------------------------------------------
557  */
558 NDIS_STATUS
559 OvsValidateUDPChecksum(PNET_BUFFER_LIST curNbl, BOOLEAN udpCsumZero)
560 {
561     NDIS_TCP_IP_CHECKSUM_NET_BUFFER_LIST_INFO csumInfo;
562
563     csumInfo.Value = NET_BUFFER_LIST_INFO(curNbl, TcpIpChecksumNetBufferListInfo);
564
565     if (udpCsumZero) {
566         /* Zero is valid checksum. */
567         csumInfo.Receive.UdpChecksumFailed = 0;
568         NET_BUFFER_LIST_INFO(curNbl, TcpIpChecksumNetBufferListInfo) = csumInfo.Value;
569         return NDIS_STATUS_SUCCESS;
570     }
571
572     /* First check if NIC has indicated UDP checksum failure. */
573     if (csumInfo.Receive.UdpChecksumFailed) {
574         return NDIS_STATUS_INVALID_PACKET;
575     }
576
577     return NDIS_STATUS_SUCCESS;
578 }