net: get rid of an signed integer overflow in ip_idents_reserve()
[cascardo/linux.git] / net / ipv4 / route.c
index 60398a9..b5b47a2 100644 (file)
@@ -476,12 +476,18 @@ u32 ip_idents_reserve(u32 hash, int segs)
        atomic_t *p_id = ip_idents + hash % IP_IDENTS_SZ;
        u32 old = ACCESS_ONCE(*p_tstamp);
        u32 now = (u32)jiffies;
-       u32 delta = 0;
+       u32 new, delta = 0;
 
        if (old != now && cmpxchg(p_tstamp, old, now) == old)
                delta = prandom_u32_max(now - old);
 
-       return atomic_add_return(segs + delta, p_id) - segs;
+       /* Do not use atomic_add_return() as it makes UBSAN unhappy */
+       do {
+               old = (u32)atomic_read(p_id);
+               new = old + delta + segs;
+       } while (atomic_cmpxchg(p_id, old, new) != old);
+
+       return new - segs;
 }
 EXPORT_SYMBOL(ip_idents_reserve);
 
@@ -915,11 +921,11 @@ static int ip_error(struct sk_buff *skb)
        if (!IN_DEV_FORWARD(in_dev)) {
                switch (rt->dst.error) {
                case EHOSTUNREACH:
-                       IP_INC_STATS_BH(net, IPSTATS_MIB_INADDRERRORS);
+                       __IP_INC_STATS(net, IPSTATS_MIB_INADDRERRORS);
                        break;
 
                case ENETUNREACH:
-                       IP_INC_STATS_BH(net, IPSTATS_MIB_INNOROUTES);
+                       __IP_INC_STATS(net, IPSTATS_MIB_INNOROUTES);
                        break;
                }
                goto out;
@@ -934,7 +940,7 @@ static int ip_error(struct sk_buff *skb)
                break;
        case ENETUNREACH:
                code = ICMP_NET_UNREACH;
-               IP_INC_STATS_BH(net, IPSTATS_MIB_INNOROUTES);
+               __IP_INC_STATS(net, IPSTATS_MIB_INNOROUTES);
                break;
        case EACCES:
                code = ICMP_PKT_FILTERED;
@@ -2146,6 +2152,7 @@ struct rtable *__ip_route_output_key_hash(struct net *net, struct flowi4 *fl4,
        unsigned int flags = 0;
        struct fib_result res;
        struct rtable *rth;
+       int master_idx;
        int orig_oif;
        int err = -ENETUNREACH;
 
@@ -2155,6 +2162,9 @@ struct rtable *__ip_route_output_key_hash(struct net *net, struct flowi4 *fl4,
 
        orig_oif = fl4->flowi4_oif;
 
+       master_idx = l3mdev_master_ifindex_by_index(net, fl4->flowi4_oif);
+       if (master_idx)
+               fl4->flowi4_oif = master_idx;
        fl4->flowi4_iif = LOOPBACK_IFINDEX;
        fl4->flowi4_tos = tos & IPTOS_RT_MASK;
        fl4->flowi4_scope = ((tos & RTO_ONLINK) ?