Merge branch 'perf-urgent-for-linus' of git://git.kernel.org/pub/scm/linux/kernel...
[cascardo/linux.git] / fs / cifs / smb2misc.c
index 1c59070..389fb9f 100644 (file)
@@ -38,7 +38,7 @@ check_smb2_hdr(struct smb2_hdr *hdr, __u64 mid)
         * Make sure that this really is an SMB, that it is a response,
         * and that the message ids match.
         */
-       if ((*(__le32 *)hdr->ProtocolId == SMB2_PROTO_NUMBER) &&
+       if ((hdr->ProtocolId == SMB2_PROTO_NUMBER) &&
            (mid == wire_mid)) {
                if (hdr->Flags & SMB2_FLAGS_SERVER_TO_REDIR)
                        return 0;
@@ -50,9 +50,9 @@ check_smb2_hdr(struct smb2_hdr *hdr, __u64 mid)
                                cifs_dbg(VFS, "Received Request not response\n");
                }
        } else { /* bad signature or mid */
-               if (*(__le32 *)hdr->ProtocolId != SMB2_PROTO_NUMBER)
+               if (hdr->ProtocolId != SMB2_PROTO_NUMBER)
                        cifs_dbg(VFS, "Bad protocol string signature header %x\n",
-                                *(unsigned int *) hdr->ProtocolId);
+                                le32_to_cpu(hdr->ProtocolId));
                if (mid != wire_mid)
                        cifs_dbg(VFS, "Mids do not match: %llu and %llu\n",
                                 mid, wire_mid);
@@ -93,11 +93,11 @@ static const __le16 smb2_rsp_struct_sizes[NUMBER_OF_SMB2_COMMANDS] = {
 };
 
 int
-smb2_check_message(char *buf, unsigned int length)
+smb2_check_message(char *buf, unsigned int length, struct TCP_Server_Info *srvr)
 {
        struct smb2_hdr *hdr = (struct smb2_hdr *)buf;
        struct smb2_pdu *pdu = (struct smb2_pdu *)hdr;
-       __u64 mid = le64_to_cpu(hdr->MessageId);
+       __u64 mid;
        __u32 len = get_rfc1002_length(buf);
        __u32 clc_len;  /* calculated length */
        int command;
@@ -111,6 +111,30 @@ smb2_check_message(char *buf, unsigned int length)
         * ie Validate the wct via smb2_struct_sizes table above
         */
 
+       if (hdr->ProtocolId == SMB2_TRANSFORM_PROTO_NUM) {
+               struct smb2_transform_hdr *thdr =
+                       (struct smb2_transform_hdr *)buf;
+               struct cifs_ses *ses = NULL;
+               struct list_head *tmp;
+
+               /* decrypt frame now that it is completely read in */
+               spin_lock(&cifs_tcp_ses_lock);
+               list_for_each(tmp, &srvr->smb_ses_list) {
+                       ses = list_entry(tmp, struct cifs_ses, smb_ses_list);
+                       if (ses->Suid == thdr->SessionId)
+                               break;
+
+                       ses = NULL;
+               }
+               spin_unlock(&cifs_tcp_ses_lock);
+               if (ses == NULL) {
+                       cifs_dbg(VFS, "no decryption - session id not found\n");
+                       return 1;
+               }
+       }
+
+
+       mid = le64_to_cpu(hdr->MessageId);
        if (length < sizeof(struct smb2_pdu)) {
                if ((length >= sizeof(struct smb2_hdr)) && (hdr->Status != 0)) {
                        pdu->StructureSize2 = 0;
@@ -322,7 +346,7 @@ smb2_get_data_area_len(int *off, int *len, struct smb2_hdr *hdr)
 
        /* return pointer to beginning of data area, ie offset from SMB start */
        if ((*off != 0) && (*len != 0))
-               return (char *)(&hdr->ProtocolId[0]) + *off;
+               return (char *)(&hdr->ProtocolId) + *off;
        else
                return NULL;
 }