Merge tag 'iio-fixes-for-4.9a' of git://git.kernel.org/pub/scm/linux/kernel/git/jic23...
[cascardo/linux.git] / net / ipv4 / inet_diag.c
index 38c2c47..e4d16fc 100644 (file)
@@ -45,6 +45,7 @@ struct inet_diag_entry {
        u16 family;
        u16 userlocks;
        u32 ifindex;
+       u32 mark;
 };
 
 static DEFINE_MUTEX(inet_diag_table_mutex);
@@ -98,6 +99,7 @@ static size_t inet_sk_attr_size(void)
                + nla_total_size(1) /* INET_DIAG_SHUTDOWN */
                + nla_total_size(1) /* INET_DIAG_TOS */
                + nla_total_size(1) /* INET_DIAG_TCLASS */
+               + nla_total_size(4) /* INET_DIAG_MARK */
                + nla_total_size(sizeof(struct inet_diag_meminfo))
                + nla_total_size(sizeof(struct inet_diag_msg))
                + nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
@@ -108,7 +110,8 @@ static size_t inet_sk_attr_size(void)
 
 int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
                             struct inet_diag_msg *r, int ext,
-                            struct user_namespace *user_ns)
+                            struct user_namespace *user_ns,
+                            bool net_admin)
 {
        const struct inet_sock *inet = inet_sk(sk);
 
@@ -135,6 +138,9 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
        }
 #endif
 
+       if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark))
+               goto errout;
+
        r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
        r->idiag_inode = sock_i_ino(sk);
 
@@ -148,7 +154,8 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
                      struct sk_buff *skb, const struct inet_diag_req_v2 *req,
                      struct user_namespace *user_ns,
                      u32 portid, u32 seq, u16 nlmsg_flags,
-                     const struct nlmsghdr *unlh)
+                     const struct nlmsghdr *unlh,
+                     bool net_admin)
 {
        const struct tcp_congestion_ops *ca_ops;
        const struct inet_diag_handler *handler;
@@ -174,7 +181,7 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
        r->idiag_timer = 0;
        r->idiag_retrans = 0;
 
-       if (inet_diag_msg_attrs_fill(sk, skb, r, ext, user_ns))
+       if (inet_diag_msg_attrs_fill(sk, skb, r, ext, user_ns, net_admin))
                goto errout;
 
        if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
@@ -273,10 +280,11 @@ static int inet_csk_diag_fill(struct sock *sk,
                              const struct inet_diag_req_v2 *req,
                              struct user_namespace *user_ns,
                              u32 portid, u32 seq, u16 nlmsg_flags,
-                             const struct nlmsghdr *unlh)
+                             const struct nlmsghdr *unlh,
+                             bool net_admin)
 {
-       return inet_sk_diag_fill(sk, inet_csk(sk), skb, req,
-                                user_ns, portid, seq, nlmsg_flags, unlh);
+       return inet_sk_diag_fill(sk, inet_csk(sk), skb, req, user_ns,
+                                portid, seq, nlmsg_flags, unlh, net_admin);
 }
 
 static int inet_twsk_diag_fill(struct sock *sk,
@@ -318,8 +326,9 @@ static int inet_twsk_diag_fill(struct sock *sk,
 
 static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
                              u32 portid, u32 seq, u16 nlmsg_flags,
-                             const struct nlmsghdr *unlh)
+                             const struct nlmsghdr *unlh, bool net_admin)
 {
+       struct request_sock *reqsk = inet_reqsk(sk);
        struct inet_diag_msg *r;
        struct nlmsghdr *nlh;
        long tmo;
@@ -333,7 +342,7 @@ static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
        inet_diag_msg_common_fill(r, sk);
        r->idiag_state = TCP_SYN_RECV;
        r->idiag_timer = 1;
-       r->idiag_retrans = inet_reqsk(sk)->num_retrans;
+       r->idiag_retrans = reqsk->num_retrans;
 
        BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
                     offsetof(struct sock, sk_cookie));
@@ -345,6 +354,10 @@ static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
        r->idiag_uid    = 0;
        r->idiag_inode  = 0;
 
+       if (net_admin && nla_put_u32(skb, INET_DIAG_MARK,
+                                    inet_rsk(reqsk)->ir_mark))
+               return -EMSGSIZE;
+
        nlmsg_end(skb, nlh);
        return 0;
 }
@@ -353,7 +366,7 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
                        const struct inet_diag_req_v2 *r,
                        struct user_namespace *user_ns,
                        u32 portid, u32 seq, u16 nlmsg_flags,
-                       const struct nlmsghdr *unlh)
+                       const struct nlmsghdr *unlh, bool net_admin)
 {
        if (sk->sk_state == TCP_TIME_WAIT)
                return inet_twsk_diag_fill(sk, skb, portid, seq,
@@ -361,10 +374,10 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
 
        if (sk->sk_state == TCP_NEW_SYN_RECV)
                return inet_req_diag_fill(sk, skb, portid, seq,
-                                         nlmsg_flags, unlh);
+                                         nlmsg_flags, unlh, net_admin);
 
        return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
-                                 nlmsg_flags, unlh);
+                                 nlmsg_flags, unlh, net_admin);
 }
 
 struct sock *inet_diag_find_one_icsk(struct net *net,
@@ -434,7 +447,8 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
        err = sk_diag_fill(sk, rep, req,
                           sk_user_ns(NETLINK_CB(in_skb).sk),
                           NETLINK_CB(in_skb).portid,
-                          nlh->nlmsg_seq, 0, nlh);
+                          nlh->nlmsg_seq, 0, nlh,
+                          netlink_net_capable(in_skb, CAP_NET_ADMIN));
        if (err < 0) {
                WARN_ON(err == -EMSGSIZE);
                nlmsg_free(rep);
@@ -580,6 +594,14 @@ static int inet_diag_bc_run(const struct nlattr *_bc,
                                yes = 0;
                        break;
                }
+               case INET_DIAG_BC_MARK_COND: {
+                       struct inet_diag_markcond *cond;
+
+                       cond = (struct inet_diag_markcond *)(op + 1);
+                       if ((entry->mark & cond->mask) != cond->mark)
+                               yes = 0;
+                       break;
+               }
                }
 
                if (yes) {
@@ -624,6 +646,12 @@ int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
        entry.dport = ntohs(inet->inet_dport);
        entry.ifindex = sk->sk_bound_dev_if;
        entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
+       if (sk_fullsock(sk))
+               entry.mark = sk->sk_mark;
+       else if (sk->sk_state == TCP_NEW_SYN_RECV)
+               entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
+       else
+               entry.mark = 0;
 
        return inet_diag_bc_run(bc, &entry);
 }
@@ -706,10 +734,25 @@ static bool valid_port_comparison(const struct inet_diag_bc_op *op,
        return true;
 }
 
-static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
+static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
+                          int *min_len)
+{
+       *min_len += sizeof(struct inet_diag_markcond);
+       return len >= *min_len;
+}
+
+static int inet_diag_bc_audit(const struct nlattr *attr,
+                             const struct sk_buff *skb)
 {
-       const void *bc = bytecode;
-       int  len = bytecode_len;
+       bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
+       const void *bytecode, *bc;
+       int bytecode_len, len;
+
+       if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
+               return -EINVAL;
+
+       bytecode = bc = nla_data(attr);
+       len = bytecode_len = nla_len(attr);
 
        while (len > 0) {
                int min_len = sizeof(struct inet_diag_bc_op);
@@ -732,6 +775,12 @@ static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
                        if (!valid_port_comparison(bc, len, &min_len))
                                return -EINVAL;
                        break;
+               case INET_DIAG_BC_MARK_COND:
+                       if (!net_admin)
+                               return -EPERM;
+                       if (!valid_markcond(bc, len, &min_len))
+                               return -EINVAL;
+                       break;
                case INET_DIAG_BC_AUTO:
                case INET_DIAG_BC_JMP:
                case INET_DIAG_BC_NOP:
@@ -760,7 +809,8 @@ static int inet_csk_diag_dump(struct sock *sk,
                              struct sk_buff *skb,
                              struct netlink_callback *cb,
                              const struct inet_diag_req_v2 *r,
-                             const struct nlattr *bc)
+                             const struct nlattr *bc,
+                             bool net_admin)
 {
        if (!inet_diag_bc_sk(bc, sk))
                return 0;
@@ -768,7 +818,8 @@ static int inet_csk_diag_dump(struct sock *sk,
        return inet_csk_diag_fill(sk, skb, r,
                                  sk_user_ns(NETLINK_CB(cb->skb).sk),
                                  NETLINK_CB(cb->skb).portid,
-                                 cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
+                                 cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh,
+                                 net_admin);
 }
 
 static void twsk_build_assert(void)
@@ -804,6 +855,7 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
        struct net *net = sock_net(skb->sk);
        int i, num, s_i, s_num;
        u32 idiag_states = r->idiag_states;
+       bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
 
        if (idiag_states & TCPF_SYN_RECV)
                idiag_states |= TCPF_NEW_SYN_RECV;
@@ -844,7 +896,8 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
                                    cb->args[3] > 0)
                                        goto next_listen;
 
-                               if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
+                               if (inet_csk_diag_dump(sk, skb, cb, r,
+                                                      bc, net_admin) < 0) {
                                        spin_unlock_bh(&ilb->lock);
                                        goto done;
                                }
@@ -912,7 +965,7 @@ skip_listen_ht:
                                           sk_user_ns(NETLINK_CB(cb->skb).sk),
                                           NETLINK_CB(cb->skb).portid,
                                           cb->nlh->nlmsg_seq, NLM_F_MULTI,
-                                          cb->nlh);
+                                          cb->nlh, net_admin);
                        if (res < 0) {
                                spin_unlock_bh(lock);
                                goto done;
@@ -1020,13 +1073,13 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
        if (nlh->nlmsg_flags & NLM_F_DUMP) {
                if (nlmsg_attrlen(nlh, hdrlen)) {
                        struct nlattr *attr;
+                       int err;
 
                        attr = nlmsg_find_attr(nlh, hdrlen,
                                               INET_DIAG_REQ_BYTECODE);
-                       if (!attr ||
-                           nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
-                           inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
-                               return -EINVAL;
+                       err = inet_diag_bc_audit(attr, skb);
+                       if (err)
+                               return err;
                }
                {
                        struct netlink_dump_control c = {
@@ -1051,13 +1104,13 @@ static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
            h->nlmsg_flags & NLM_F_DUMP) {
                if (nlmsg_attrlen(h, hdrlen)) {
                        struct nlattr *attr;
+                       int err;
 
                        attr = nlmsg_find_attr(h, hdrlen,
                                               INET_DIAG_REQ_BYTECODE);
-                       if (!attr ||
-                           nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
-                           inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
-                               return -EINVAL;
+                       err = inet_diag_bc_audit(attr, skb);
+                       if (err)
+                               return err;
                }
                {
                        struct netlink_dump_control c = {