Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[cascardo/linux.git] / net / netlink / af_netlink.c
index f22757a..15c731f 100644 (file)
@@ -1206,7 +1206,8 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
        struct module *module = NULL;
        struct mutex *cb_mutex;
        struct netlink_sock *nlk;
-       void (*bind)(int group);
+       int (*bind)(int group);
+       void (*unbind)(int group);
        int err = 0;
 
        sock->state = SS_UNCONNECTED;
@@ -1232,6 +1233,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
                err = -EPROTONOSUPPORT;
        cb_mutex = nl_table[protocol].cb_mutex;
        bind = nl_table[protocol].bind;
+       unbind = nl_table[protocol].unbind;
        netlink_unlock_table();
 
        if (err < 0)
@@ -1248,6 +1250,7 @@ static int netlink_create(struct net *net, struct socket *sock, int protocol,
        nlk = nlk_sk(sock->sk);
        nlk->module = module;
        nlk->netlink_bind = bind;
+       nlk->netlink_unbind = unbind;
 out:
        return err;
 
@@ -1301,6 +1304,7 @@ static int netlink_release(struct socket *sock)
                        kfree_rcu(old, rcu);
                        nl_table[sk->sk_protocol].module = NULL;
                        nl_table[sk->sk_protocol].bind = NULL;
+                       nl_table[sk->sk_protocol].unbind = NULL;
                        nl_table[sk->sk_protocol].flags = 0;
                        nl_table[sk->sk_protocol].registered = 0;
                }
@@ -1478,6 +1482,19 @@ static int netlink_realloc_groups(struct sock *sk)
        return err;
 }
 
+static void netlink_unbind(int group, long unsigned int groups,
+                          struct netlink_sock *nlk)
+{
+       int undo;
+
+       if (!nlk->netlink_unbind)
+               return;
+
+       for (undo = 0; undo < group; undo++)
+               if (test_bit(group, &groups))
+                       nlk->netlink_unbind(undo);
+}
+
 static int netlink_bind(struct socket *sock, struct sockaddr *addr,
                        int addr_len)
 {
@@ -1486,6 +1503,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
        struct netlink_sock *nlk = nlk_sk(sk);
        struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr;
        int err;
+       long unsigned int groups = nladdr->nl_groups;
 
        if (addr_len < sizeof(struct sockaddr_nl))
                return -EINVAL;
@@ -1494,7 +1512,7 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
                return -EINVAL;
 
        /* Only superuser is allowed to listen multicasts */
-       if (nladdr->nl_groups) {
+       if (groups) {
                if (!netlink_allowed(sock, NL_CFG_F_NONROOT_RECV))
                        return -EPERM;
                err = netlink_realloc_groups(sk);
@@ -1502,37 +1520,45 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr,
                        return err;
        }
 
-       if (nlk->portid) {
+       if (nlk->portid)
                if (nladdr->nl_pid != nlk->portid)
                        return -EINVAL;
-       } else {
+
+       if (nlk->netlink_bind && groups) {
+               int group;
+
+               for (group = 0; group < nlk->ngroups; group++) {
+                       if (!test_bit(group, &groups))
+                               continue;
+                       err = nlk->netlink_bind(group);
+                       if (!err)
+                               continue;
+                       netlink_unbind(group, groups, nlk);
+                       return err;
+               }
+       }
+
+       if (!nlk->portid) {
                err = nladdr->nl_pid ?
                        netlink_insert(sk, net, nladdr->nl_pid) :
                        netlink_autobind(sock);
-               if (err)
+               if (err) {
+                       netlink_unbind(nlk->ngroups - 1, groups, nlk);
                        return err;
+               }
        }
 
-       if (!nladdr->nl_groups && (nlk->groups == NULL || !(u32)nlk->groups[0]))
+       if (!groups && (nlk->groups == NULL || !(u32)nlk->groups[0]))
                return 0;
 
        netlink_table_grab();
        netlink_update_subscriptions(sk, nlk->subscriptions +
-                                        hweight32(nladdr->nl_groups) -
+                                        hweight32(groups) -
                                         hweight32(nlk->groups[0]));
-       nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups;
+       nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | groups;
        netlink_update_listeners(sk);
        netlink_table_ungrab();
 
-       if (nlk->netlink_bind && nlk->groups[0]) {
-               int i;
-
-               for (i = 0; i < nlk->ngroups; i++) {
-                       if (test_bit(i, nlk->groups))
-                               nlk->netlink_bind(i);
-               }
-       }
-
        return 0;
 }
 
@@ -2170,13 +2196,17 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
                        return err;
                if (!val || val - 1 >= nlk->ngroups)
                        return -EINVAL;
+               if (optname == NETLINK_ADD_MEMBERSHIP && nlk->netlink_bind) {
+                       err = nlk->netlink_bind(val);
+                       if (err)
+                               return err;
+               }
                netlink_table_grab();
                netlink_update_socket_mc(nlk, val,
                                         optname == NETLINK_ADD_MEMBERSHIP);
                netlink_table_ungrab();
-
-               if (nlk->netlink_bind)
-                       nlk->netlink_bind(val);
+               if (optname == NETLINK_DROP_MEMBERSHIP && nlk->netlink_unbind)
+                       nlk->netlink_unbind(val);
 
                err = 0;
                break;