ipv4,ipv6: grab rtnl before locking the socket
[cascardo/linux.git] / net / ipv4 / ip_sockglue.c
index 5cd9927..5171709 100644 (file)
@@ -536,12 +536,25 @@ out:
  *     Socket option code for IP. This is the end of the line after any
  *     TCP,UDP etc options on an IP socket.
  */
+static bool setsockopt_needs_rtnl(int optname)
+{
+       switch (optname) {
+       case IP_ADD_MEMBERSHIP:
+       case IP_ADD_SOURCE_MEMBERSHIP:
+       case IP_DROP_MEMBERSHIP:
+       case MCAST_JOIN_GROUP:
+       case MCAST_LEAVE_GROUP:
+               return true;
+       }
+       return false;
+}
 
 static int do_ip_setsockopt(struct sock *sk, int level,
                            int optname, char __user *optval, unsigned int optlen)
 {
        struct inet_sock *inet = inet_sk(sk);
        int val = 0, err;
+       bool needs_rtnl = setsockopt_needs_rtnl(optname);
 
        switch (optname) {
        case IP_PKTINFO:
@@ -584,6 +597,8 @@ static int do_ip_setsockopt(struct sock *sk, int level,
                return ip_mroute_setsockopt(sk, optname, optval, optlen);
 
        err = 0;
+       if (needs_rtnl)
+               rtnl_lock();
        lock_sock(sk);
 
        switch (optname) {
@@ -846,9 +861,9 @@ static int do_ip_setsockopt(struct sock *sk, int level,
                }
 
                if (optname == IP_ADD_MEMBERSHIP)
-                       err = ip_mc_join_group(sk, &mreq);
+                       err = __ip_mc_join_group(sk, &mreq);
                else
-                       err = ip_mc_leave_group(sk, &mreq);
+                       err = __ip_mc_leave_group(sk, &mreq);
                break;
        }
        case IP_MSFILTER:
@@ -913,7 +928,7 @@ static int do_ip_setsockopt(struct sock *sk, int level,
                        mreq.imr_multiaddr.s_addr = mreqs.imr_multiaddr;
                        mreq.imr_address.s_addr = mreqs.imr_interface;
                        mreq.imr_ifindex = 0;
-                       err = ip_mc_join_group(sk, &mreq);
+                       err = __ip_mc_join_group(sk, &mreq);
                        if (err && err != -EADDRINUSE)
                                break;
                        omode = MCAST_INCLUDE;
@@ -945,9 +960,9 @@ static int do_ip_setsockopt(struct sock *sk, int level,
                mreq.imr_ifindex = greq.gr_interface;
 
                if (optname == MCAST_JOIN_GROUP)
-                       err = ip_mc_join_group(sk, &mreq);
+                       err = __ip_mc_join_group(sk, &mreq);
                else
-                       err = ip_mc_leave_group(sk, &mreq);
+                       err = __ip_mc_leave_group(sk, &mreq);
                break;
        }
        case MCAST_JOIN_SOURCE_GROUP:
@@ -990,7 +1005,7 @@ static int do_ip_setsockopt(struct sock *sk, int level,
                        mreq.imr_multiaddr = psin->sin_addr;
                        mreq.imr_address.s_addr = 0;
                        mreq.imr_ifindex = greqs.gsr_interface;
-                       err = ip_mc_join_group(sk, &mreq);
+                       err = __ip_mc_join_group(sk, &mreq);
                        if (err && err != -EADDRINUSE)
                                break;
                        greqs.gsr_interface = mreq.imr_ifindex;
@@ -1118,10 +1133,14 @@ mc_msf_out:
                break;
        }
        release_sock(sk);
+       if (needs_rtnl)
+               rtnl_unlock();
        return err;
 
 e_inval:
        release_sock(sk);
+       if (needs_rtnl)
+               rtnl_unlock();
        return -EINVAL;
 }