openvswitch: __nf_ct_l{3,4}proto_find() always return a valid pointer
[cascardo/linux.git] / net / openvswitch / conntrack.c
index dc5eb29..9f0bc49 100644 (file)
@@ -367,6 +367,7 @@ static int handle_fragments(struct net *net, struct sw_flow_key *key,
        } else if (key->eth.type == htons(ETH_P_IPV6)) {
                enum ip6_defrag_users user = IP6_DEFRAG_CONNTRACK_IN + zone;
 
+               skb_orphan(skb);
                memset(IP6CB(skb), 0, sizeof(struct inet6_skb_parm));
                err = nf_ct_frag6_gather(net, skb, user);
                if (err)
@@ -438,20 +439,12 @@ ovs_ct_find_existing(struct net *net, const struct nf_conntrack_zone *zone,
        u8 protonum;
 
        l3proto = __nf_ct_l3proto_find(l3num);
-       if (!l3proto) {
-               pr_debug("ovs_ct_find_existing: Can't get l3proto\n");
-               return NULL;
-       }
        if (l3proto->get_l4proto(skb, skb_network_offset(skb), &dataoff,
                                 &protonum) <= 0) {
                pr_debug("ovs_ct_find_existing: Can't get protonum\n");
                return NULL;
        }
        l4proto = __nf_ct_l4proto_find(l3num, protonum);
-       if (!l4proto) {
-               pr_debug("ovs_ct_find_existing: Can't get l4proto\n");
-               return NULL;
-       }
        if (!nf_ct_get_tuple(skb, skb_network_offset(skb), dataoff, l3num,
                             protonum, net, &tuple, l3proto, l4proto)) {
                pr_debug("ovs_ct_find_existing: Can't get tuple\n");
@@ -535,14 +528,15 @@ static int ovs_ct_nat_execute(struct sk_buff *skb, struct nf_conn *ct,
        switch (ctinfo) {
        case IP_CT_RELATED:
        case IP_CT_RELATED_REPLY:
-               if (skb->protocol == htons(ETH_P_IP) &&
+               if (IS_ENABLED(CONFIG_NF_NAT_IPV4) &&
+                   skb->protocol == htons(ETH_P_IP) &&
                    ip_hdr(skb)->protocol == IPPROTO_ICMP) {
                        if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo,
                                                           hooknum))
                                err = NF_DROP;
                        goto push;
-#if IS_ENABLED(CONFIG_NF_NAT_IPV6)
-               } else if (skb->protocol == htons(ETH_P_IPV6)) {
+               } else if (IS_ENABLED(CONFIG_NF_NAT_IPV6) &&
+                          skb->protocol == htons(ETH_P_IPV6)) {
                        __be16 frag_off;
                        u8 nexthdr = ipv6_hdr(skb)->nexthdr;
                        int hdrlen = ipv6_skip_exthdr(skb,
@@ -557,7 +551,6 @@ static int ovs_ct_nat_execute(struct sk_buff *skb, struct nf_conn *ct,
                                        err = NF_DROP;
                                goto push;
                        }
-#endif
                }
                /* Non-ICMP, fall thru to initialize if needed. */
        case IP_CT_NEW:
@@ -664,11 +657,12 @@ static int ovs_ct_nat(struct net *net, struct sw_flow_key *key,
 
        /* Determine NAT type.
         * Check if the NAT type can be deduced from the tracked connection.
-        * Make sure expected traffic is NATted only when committing.
+        * Make sure new expected connections (IP_CT_RELATED) are NATted only
+        * when committing.
         */
        if (info->nat & OVS_CT_NAT && ctinfo != IP_CT_NEW &&
            ct->status & IPS_NAT_MASK &&
-           (!(ct->status & IPS_EXPECTED_BIT) || info->commit)) {
+           (ctinfo != IP_CT_RELATED || info->commit)) {
                /* NAT an established or related connection like before. */
                if (CTINFO2DIR(ctinfo) == IP_CT_DIR_REPLY)
                        /* This is the REPLY direction for a connection
@@ -968,7 +962,8 @@ static int parse_nat(const struct nlattr *attr,
                        break;
 
                case OVS_NAT_ATTR_IP_MIN:
-                       nla_memcpy(&info->range.min_addr, a, nla_len(a));
+                       nla_memcpy(&info->range.min_addr, a,
+                                  sizeof(info->range.min_addr));
                        info->range.flags |= NF_NAT_RANGE_MAP_IPS;
                        break;
 
@@ -1238,7 +1233,8 @@ static bool ovs_ct_nat_to_attr(const struct ovs_conntrack_info *info,
        }
 
        if (info->range.flags & NF_NAT_RANGE_MAP_IPS) {
-               if (info->family == NFPROTO_IPV4) {
+               if (IS_ENABLED(CONFIG_NF_NAT_IPV4) &&
+                   info->family == NFPROTO_IPV4) {
                        if (nla_put_in_addr(skb, OVS_NAT_ATTR_IP_MIN,
                                            info->range.min_addr.ip) ||
                            (info->range.max_addr.ip
@@ -1246,8 +1242,8 @@ static bool ovs_ct_nat_to_attr(const struct ovs_conntrack_info *info,
                             (nla_put_in_addr(skb, OVS_NAT_ATTR_IP_MAX,
                                              info->range.max_addr.ip))))
                                return false;
-#if IS_ENABLED(CONFIG_NF_NAT_IPV6)
-               } else if (info->family == NFPROTO_IPV6) {
+               } else if (IS_ENABLED(CONFIG_NF_NAT_IPV6) &&
+                          info->family == NFPROTO_IPV6) {
                        if (nla_put_in6_addr(skb, OVS_NAT_ATTR_IP_MIN,
                                             &info->range.min_addr.in6) ||
                            (memcmp(&info->range.max_addr.in6,
@@ -1256,7 +1252,6 @@ static bool ovs_ct_nat_to_attr(const struct ovs_conntrack_info *info,
                             (nla_put_in6_addr(skb, OVS_NAT_ATTR_IP_MAX,
                                               &info->range.max_addr.in6))))
                                return false;
-#endif
                } else {
                        return false;
                }
@@ -1342,7 +1337,7 @@ void ovs_ct_init(struct net *net)
        unsigned int n_bits = sizeof(struct ovs_key_ct_labels) * BITS_PER_BYTE;
        struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
 
-       if (nf_connlabels_get(net, n_bits)) {
+       if (nf_connlabels_get(net, n_bits - 1)) {
                ovs_net->xt_label = false;
                OVS_NLERR(true, "Failed to set connlabel length");
        } else {