Merge tag 'armsoc-arm64' of git://git.kernel.org/pub/scm/linux/kernel/git/arm/arm-soc
[cascardo/linux.git] / net / ipv4 / inet_diag.c
1 /*
2  * inet_diag.c  Module for monitoring INET transport protocols sockets.
3  *
4  * Authors:     Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
5  *
6  *      This program is free software; you can redistribute it and/or
7  *      modify it under the terms of the GNU General Public License
8  *      as published by the Free Software Foundation; either version
9  *      2 of the License, or (at your option) any later version.
10  */
11
12 #include <linux/kernel.h>
13 #include <linux/module.h>
14 #include <linux/types.h>
15 #include <linux/fcntl.h>
16 #include <linux/random.h>
17 #include <linux/slab.h>
18 #include <linux/cache.h>
19 #include <linux/init.h>
20 #include <linux/time.h>
21
22 #include <net/icmp.h>
23 #include <net/tcp.h>
24 #include <net/ipv6.h>
25 #include <net/inet_common.h>
26 #include <net/inet_connection_sock.h>
27 #include <net/inet_hashtables.h>
28 #include <net/inet_timewait_sock.h>
29 #include <net/inet6_hashtables.h>
30 #include <net/netlink.h>
31
32 #include <linux/inet.h>
33 #include <linux/stddef.h>
34
35 #include <linux/inet_diag.h>
36 #include <linux/sock_diag.h>
37
38 static const struct inet_diag_handler **inet_diag_table;
39
40 struct inet_diag_entry {
41         const __be32 *saddr;
42         const __be32 *daddr;
43         u16 sport;
44         u16 dport;
45         u16 family;
46         u16 userlocks;
47 };
48
49 static DEFINE_MUTEX(inet_diag_table_mutex);
50
51 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
52 {
53         if (!inet_diag_table[proto])
54                 request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
55                                NETLINK_SOCK_DIAG, AF_INET, proto);
56
57         mutex_lock(&inet_diag_table_mutex);
58         if (!inet_diag_table[proto])
59                 return ERR_PTR(-ENOENT);
60
61         return inet_diag_table[proto];
62 }
63
64 static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
65 {
66         mutex_unlock(&inet_diag_table_mutex);
67 }
68
69 void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
70 {
71         r->idiag_family = sk->sk_family;
72
73         r->id.idiag_sport = htons(sk->sk_num);
74         r->id.idiag_dport = sk->sk_dport;
75         r->id.idiag_if = sk->sk_bound_dev_if;
76         sock_diag_save_cookie(sk, r->id.idiag_cookie);
77
78 #if IS_ENABLED(CONFIG_IPV6)
79         if (sk->sk_family == AF_INET6) {
80                 *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
81                 *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
82         } else
83 #endif
84         {
85         memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
86         memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
87
88         r->id.idiag_src[0] = sk->sk_rcv_saddr;
89         r->id.idiag_dst[0] = sk->sk_daddr;
90         }
91 }
92 EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill);
93
94 static size_t inet_sk_attr_size(void)
95 {
96         return    nla_total_size(sizeof(struct tcp_info))
97                 + nla_total_size(1) /* INET_DIAG_SHUTDOWN */
98                 + nla_total_size(1) /* INET_DIAG_TOS */
99                 + nla_total_size(1) /* INET_DIAG_TCLASS */
100                 + nla_total_size(sizeof(struct inet_diag_meminfo))
101                 + nla_total_size(sizeof(struct inet_diag_msg))
102                 + nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
103                 + nla_total_size(TCP_CA_NAME_MAX)
104                 + nla_total_size(sizeof(struct tcpvegas_info))
105                 + 64;
106 }
107
108 int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
109                              struct inet_diag_msg *r, int ext,
110                              struct user_namespace *user_ns)
111 {
112         const struct inet_sock *inet = inet_sk(sk);
113
114         if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
115                 goto errout;
116
117         /* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
118          * hence this needs to be included regardless of socket family.
119          */
120         if (ext & (1 << (INET_DIAG_TOS - 1)))
121                 if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
122                         goto errout;
123
124 #if IS_ENABLED(CONFIG_IPV6)
125         if (r->idiag_family == AF_INET6) {
126                 if (ext & (1 << (INET_DIAG_TCLASS - 1)))
127                         if (nla_put_u8(skb, INET_DIAG_TCLASS,
128                                        inet6_sk(sk)->tclass) < 0)
129                                 goto errout;
130
131                 if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
132                     nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
133                         goto errout;
134         }
135 #endif
136
137         r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
138         r->idiag_inode = sock_i_ino(sk);
139
140         return 0;
141 errout:
142         return 1;
143 }
144 EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
145
146 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
147                       struct sk_buff *skb, const struct inet_diag_req_v2 *req,
148                       struct user_namespace *user_ns,
149                       u32 portid, u32 seq, u16 nlmsg_flags,
150                       const struct nlmsghdr *unlh)
151 {
152         const struct tcp_congestion_ops *ca_ops;
153         const struct inet_diag_handler *handler;
154         int ext = req->idiag_ext;
155         struct inet_diag_msg *r;
156         struct nlmsghdr  *nlh;
157         struct nlattr *attr;
158         void *info = NULL;
159
160         handler = inet_diag_table[req->sdiag_protocol];
161         BUG_ON(!handler);
162
163         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
164                         nlmsg_flags);
165         if (!nlh)
166                 return -EMSGSIZE;
167
168         r = nlmsg_data(nlh);
169         BUG_ON(!sk_fullsock(sk));
170
171         inet_diag_msg_common_fill(r, sk);
172         r->idiag_state = sk->sk_state;
173         r->idiag_timer = 0;
174         r->idiag_retrans = 0;
175
176         if (inet_diag_msg_attrs_fill(sk, skb, r, ext, user_ns))
177                 goto errout;
178
179         if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
180                 struct inet_diag_meminfo minfo = {
181                         .idiag_rmem = sk_rmem_alloc_get(sk),
182                         .idiag_wmem = sk->sk_wmem_queued,
183                         .idiag_fmem = sk->sk_forward_alloc,
184                         .idiag_tmem = sk_wmem_alloc_get(sk),
185                 };
186
187                 if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
188                         goto errout;
189         }
190
191         if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
192                 if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
193                         goto errout;
194
195         if (!icsk) {
196                 handler->idiag_get_info(sk, r, NULL);
197                 goto out;
198         }
199
200         if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
201             icsk->icsk_pending == ICSK_TIME_EARLY_RETRANS ||
202             icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
203                 r->idiag_timer = 1;
204                 r->idiag_retrans = icsk->icsk_retransmits;
205                 r->idiag_expires =
206                         jiffies_to_msecs(icsk->icsk_timeout - jiffies);
207         } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
208                 r->idiag_timer = 4;
209                 r->idiag_retrans = icsk->icsk_probes_out;
210                 r->idiag_expires =
211                         jiffies_to_msecs(icsk->icsk_timeout - jiffies);
212         } else if (timer_pending(&sk->sk_timer)) {
213                 r->idiag_timer = 2;
214                 r->idiag_retrans = icsk->icsk_probes_out;
215                 r->idiag_expires =
216                         jiffies_to_msecs(sk->sk_timer.expires - jiffies);
217         } else {
218                 r->idiag_timer = 0;
219                 r->idiag_expires = 0;
220         }
221
222         if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
223                 attr = nla_reserve_64bit(skb, INET_DIAG_INFO,
224                                          handler->idiag_info_size,
225                                          INET_DIAG_PAD);
226                 if (!attr)
227                         goto errout;
228
229                 info = nla_data(attr);
230         }
231
232         if (ext & (1 << (INET_DIAG_CONG - 1))) {
233                 int err = 0;
234
235                 rcu_read_lock();
236                 ca_ops = READ_ONCE(icsk->icsk_ca_ops);
237                 if (ca_ops)
238                         err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
239                 rcu_read_unlock();
240                 if (err < 0)
241                         goto errout;
242         }
243
244         handler->idiag_get_info(sk, r, info);
245
246         if (sk->sk_state < TCP_TIME_WAIT) {
247                 union tcp_cc_info info;
248                 size_t sz = 0;
249                 int attr;
250
251                 rcu_read_lock();
252                 ca_ops = READ_ONCE(icsk->icsk_ca_ops);
253                 if (ca_ops && ca_ops->get_info)
254                         sz = ca_ops->get_info(sk, ext, &attr, &info);
255                 rcu_read_unlock();
256                 if (sz && nla_put(skb, attr, sz, &info) < 0)
257                         goto errout;
258         }
259
260 out:
261         nlmsg_end(skb, nlh);
262         return 0;
263
264 errout:
265         nlmsg_cancel(skb, nlh);
266         return -EMSGSIZE;
267 }
268 EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
269
270 static int inet_csk_diag_fill(struct sock *sk,
271                               struct sk_buff *skb,
272                               const struct inet_diag_req_v2 *req,
273                               struct user_namespace *user_ns,
274                               u32 portid, u32 seq, u16 nlmsg_flags,
275                               const struct nlmsghdr *unlh)
276 {
277         return inet_sk_diag_fill(sk, inet_csk(sk), skb, req,
278                                  user_ns, portid, seq, nlmsg_flags, unlh);
279 }
280
281 static int inet_twsk_diag_fill(struct sock *sk,
282                                struct sk_buff *skb,
283                                u32 portid, u32 seq, u16 nlmsg_flags,
284                                const struct nlmsghdr *unlh)
285 {
286         struct inet_timewait_sock *tw = inet_twsk(sk);
287         struct inet_diag_msg *r;
288         struct nlmsghdr *nlh;
289         long tmo;
290
291         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
292                         nlmsg_flags);
293         if (!nlh)
294                 return -EMSGSIZE;
295
296         r = nlmsg_data(nlh);
297         BUG_ON(tw->tw_state != TCP_TIME_WAIT);
298
299         tmo = tw->tw_timer.expires - jiffies;
300         if (tmo < 0)
301                 tmo = 0;
302
303         inet_diag_msg_common_fill(r, sk);
304         r->idiag_retrans      = 0;
305
306         r->idiag_state        = tw->tw_substate;
307         r->idiag_timer        = 3;
308         r->idiag_expires      = jiffies_to_msecs(tmo);
309         r->idiag_rqueue       = 0;
310         r->idiag_wqueue       = 0;
311         r->idiag_uid          = 0;
312         r->idiag_inode        = 0;
313
314         nlmsg_end(skb, nlh);
315         return 0;
316 }
317
318 static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
319                               u32 portid, u32 seq, u16 nlmsg_flags,
320                               const struct nlmsghdr *unlh)
321 {
322         struct inet_diag_msg *r;
323         struct nlmsghdr *nlh;
324         long tmo;
325
326         nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
327                         nlmsg_flags);
328         if (!nlh)
329                 return -EMSGSIZE;
330
331         r = nlmsg_data(nlh);
332         inet_diag_msg_common_fill(r, sk);
333         r->idiag_state = TCP_SYN_RECV;
334         r->idiag_timer = 1;
335         r->idiag_retrans = inet_reqsk(sk)->num_retrans;
336
337         BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
338                      offsetof(struct sock, sk_cookie));
339
340         tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
341         r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
342         r->idiag_rqueue = 0;
343         r->idiag_wqueue = 0;
344         r->idiag_uid    = 0;
345         r->idiag_inode  = 0;
346
347         nlmsg_end(skb, nlh);
348         return 0;
349 }
350
351 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
352                         const struct inet_diag_req_v2 *r,
353                         struct user_namespace *user_ns,
354                         u32 portid, u32 seq, u16 nlmsg_flags,
355                         const struct nlmsghdr *unlh)
356 {
357         if (sk->sk_state == TCP_TIME_WAIT)
358                 return inet_twsk_diag_fill(sk, skb, portid, seq,
359                                            nlmsg_flags, unlh);
360
361         if (sk->sk_state == TCP_NEW_SYN_RECV)
362                 return inet_req_diag_fill(sk, skb, portid, seq,
363                                           nlmsg_flags, unlh);
364
365         return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
366                                   nlmsg_flags, unlh);
367 }
368
369 struct sock *inet_diag_find_one_icsk(struct net *net,
370                                      struct inet_hashinfo *hashinfo,
371                                      const struct inet_diag_req_v2 *req)
372 {
373         struct sock *sk;
374
375         rcu_read_lock();
376         if (req->sdiag_family == AF_INET)
377                 sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0],
378                                  req->id.idiag_dport, req->id.idiag_src[0],
379                                  req->id.idiag_sport, req->id.idiag_if);
380 #if IS_ENABLED(CONFIG_IPV6)
381         else if (req->sdiag_family == AF_INET6) {
382                 if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
383                     ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
384                         sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3],
385                                          req->id.idiag_dport, req->id.idiag_src[3],
386                                          req->id.idiag_sport, req->id.idiag_if);
387                 else
388                         sk = inet6_lookup(net, hashinfo, NULL, 0,
389                                           (struct in6_addr *)req->id.idiag_dst,
390                                           req->id.idiag_dport,
391                                           (struct in6_addr *)req->id.idiag_src,
392                                           req->id.idiag_sport,
393                                           req->id.idiag_if);
394         }
395 #endif
396         else {
397                 rcu_read_unlock();
398                 return ERR_PTR(-EINVAL);
399         }
400         rcu_read_unlock();
401         if (!sk)
402                 return ERR_PTR(-ENOENT);
403
404         if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
405                 sock_gen_put(sk);
406                 return ERR_PTR(-ENOENT);
407         }
408
409         return sk;
410 }
411 EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
412
413 int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
414                             struct sk_buff *in_skb,
415                             const struct nlmsghdr *nlh,
416                             const struct inet_diag_req_v2 *req)
417 {
418         struct net *net = sock_net(in_skb->sk);
419         struct sk_buff *rep;
420         struct sock *sk;
421         int err;
422
423         sk = inet_diag_find_one_icsk(net, hashinfo, req);
424         if (IS_ERR(sk))
425                 return PTR_ERR(sk);
426
427         rep = nlmsg_new(inet_sk_attr_size(), GFP_KERNEL);
428         if (!rep) {
429                 err = -ENOMEM;
430                 goto out;
431         }
432
433         err = sk_diag_fill(sk, rep, req,
434                            sk_user_ns(NETLINK_CB(in_skb).sk),
435                            NETLINK_CB(in_skb).portid,
436                            nlh->nlmsg_seq, 0, nlh);
437         if (err < 0) {
438                 WARN_ON(err == -EMSGSIZE);
439                 nlmsg_free(rep);
440                 goto out;
441         }
442         err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
443                               MSG_DONTWAIT);
444         if (err > 0)
445                 err = 0;
446
447 out:
448         if (sk)
449                 sock_gen_put(sk);
450
451         return err;
452 }
453 EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
454
455 static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
456                                const struct nlmsghdr *nlh,
457                                const struct inet_diag_req_v2 *req)
458 {
459         const struct inet_diag_handler *handler;
460         int err;
461
462         handler = inet_diag_lock_handler(req->sdiag_protocol);
463         if (IS_ERR(handler))
464                 err = PTR_ERR(handler);
465         else if (cmd == SOCK_DIAG_BY_FAMILY)
466                 err = handler->dump_one(in_skb, nlh, req);
467         else if (cmd == SOCK_DESTROY && handler->destroy)
468                 err = handler->destroy(in_skb, req);
469         else
470                 err = -EOPNOTSUPP;
471         inet_diag_unlock_handler(handler);
472
473         return err;
474 }
475
476 static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
477 {
478         int words = bits >> 5;
479
480         bits &= 0x1f;
481
482         if (words) {
483                 if (memcmp(a1, a2, words << 2))
484                         return 0;
485         }
486         if (bits) {
487                 __be32 w1, w2;
488                 __be32 mask;
489
490                 w1 = a1[words];
491                 w2 = a2[words];
492
493                 mask = htonl((0xffffffff) << (32 - bits));
494
495                 if ((w1 ^ w2) & mask)
496                         return 0;
497         }
498
499         return 1;
500 }
501
502 static int inet_diag_bc_run(const struct nlattr *_bc,
503                             const struct inet_diag_entry *entry)
504 {
505         const void *bc = nla_data(_bc);
506         int len = nla_len(_bc);
507
508         while (len > 0) {
509                 int yes = 1;
510                 const struct inet_diag_bc_op *op = bc;
511
512                 switch (op->code) {
513                 case INET_DIAG_BC_NOP:
514                         break;
515                 case INET_DIAG_BC_JMP:
516                         yes = 0;
517                         break;
518                 case INET_DIAG_BC_S_GE:
519                         yes = entry->sport >= op[1].no;
520                         break;
521                 case INET_DIAG_BC_S_LE:
522                         yes = entry->sport <= op[1].no;
523                         break;
524                 case INET_DIAG_BC_D_GE:
525                         yes = entry->dport >= op[1].no;
526                         break;
527                 case INET_DIAG_BC_D_LE:
528                         yes = entry->dport <= op[1].no;
529                         break;
530                 case INET_DIAG_BC_AUTO:
531                         yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
532                         break;
533                 case INET_DIAG_BC_S_COND:
534                 case INET_DIAG_BC_D_COND: {
535                         const struct inet_diag_hostcond *cond;
536                         const __be32 *addr;
537
538                         cond = (const struct inet_diag_hostcond *)(op + 1);
539                         if (cond->port != -1 &&
540                             cond->port != (op->code == INET_DIAG_BC_S_COND ?
541                                              entry->sport : entry->dport)) {
542                                 yes = 0;
543                                 break;
544                         }
545
546                         if (op->code == INET_DIAG_BC_S_COND)
547                                 addr = entry->saddr;
548                         else
549                                 addr = entry->daddr;
550
551                         if (cond->family != AF_UNSPEC &&
552                             cond->family != entry->family) {
553                                 if (entry->family == AF_INET6 &&
554                                     cond->family == AF_INET) {
555                                         if (addr[0] == 0 && addr[1] == 0 &&
556                                             addr[2] == htonl(0xffff) &&
557                                             bitstring_match(addr + 3,
558                                                             cond->addr,
559                                                             cond->prefix_len))
560                                                 break;
561                                 }
562                                 yes = 0;
563                                 break;
564                         }
565
566                         if (cond->prefix_len == 0)
567                                 break;
568                         if (bitstring_match(addr, cond->addr,
569                                             cond->prefix_len))
570                                 break;
571                         yes = 0;
572                         break;
573                 }
574                 }
575
576                 if (yes) {
577                         len -= op->yes;
578                         bc += op->yes;
579                 } else {
580                         len -= op->no;
581                         bc += op->no;
582                 }
583         }
584         return len == 0;
585 }
586
587 /* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
588  */
589 static void entry_fill_addrs(struct inet_diag_entry *entry,
590                              const struct sock *sk)
591 {
592 #if IS_ENABLED(CONFIG_IPV6)
593         if (sk->sk_family == AF_INET6) {
594                 entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
595                 entry->daddr = sk->sk_v6_daddr.s6_addr32;
596         } else
597 #endif
598         {
599                 entry->saddr = &sk->sk_rcv_saddr;
600                 entry->daddr = &sk->sk_daddr;
601         }
602 }
603
604 int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
605 {
606         struct inet_sock *inet = inet_sk(sk);
607         struct inet_diag_entry entry;
608
609         if (!bc)
610                 return 1;
611
612         entry.family = sk->sk_family;
613         entry_fill_addrs(&entry, sk);
614         entry.sport = inet->inet_num;
615         entry.dport = ntohs(inet->inet_dport);
616         entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
617
618         return inet_diag_bc_run(bc, &entry);
619 }
620 EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
621
622 static int valid_cc(const void *bc, int len, int cc)
623 {
624         while (len >= 0) {
625                 const struct inet_diag_bc_op *op = bc;
626
627                 if (cc > len)
628                         return 0;
629                 if (cc == len)
630                         return 1;
631                 if (op->yes < 4 || op->yes & 3)
632                         return 0;
633                 len -= op->yes;
634                 bc  += op->yes;
635         }
636         return 0;
637 }
638
639 /* Validate an inet_diag_hostcond. */
640 static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
641                            int *min_len)
642 {
643         struct inet_diag_hostcond *cond;
644         int addr_len;
645
646         /* Check hostcond space. */
647         *min_len += sizeof(struct inet_diag_hostcond);
648         if (len < *min_len)
649                 return false;
650         cond = (struct inet_diag_hostcond *)(op + 1);
651
652         /* Check address family and address length. */
653         switch (cond->family) {
654         case AF_UNSPEC:
655                 addr_len = 0;
656                 break;
657         case AF_INET:
658                 addr_len = sizeof(struct in_addr);
659                 break;
660         case AF_INET6:
661                 addr_len = sizeof(struct in6_addr);
662                 break;
663         default:
664                 return false;
665         }
666         *min_len += addr_len;
667         if (len < *min_len)
668                 return false;
669
670         /* Check prefix length (in bits) vs address length (in bytes). */
671         if (cond->prefix_len > 8 * addr_len)
672                 return false;
673
674         return true;
675 }
676
677 /* Validate a port comparison operator. */
678 static bool valid_port_comparison(const struct inet_diag_bc_op *op,
679                                   int len, int *min_len)
680 {
681         /* Port comparisons put the port in a follow-on inet_diag_bc_op. */
682         *min_len += sizeof(struct inet_diag_bc_op);
683         if (len < *min_len)
684                 return false;
685         return true;
686 }
687
688 static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
689 {
690         const void *bc = bytecode;
691         int  len = bytecode_len;
692
693         while (len > 0) {
694                 int min_len = sizeof(struct inet_diag_bc_op);
695                 const struct inet_diag_bc_op *op = bc;
696
697                 switch (op->code) {
698                 case INET_DIAG_BC_S_COND:
699                 case INET_DIAG_BC_D_COND:
700                         if (!valid_hostcond(bc, len, &min_len))
701                                 return -EINVAL;
702                         break;
703                 case INET_DIAG_BC_S_GE:
704                 case INET_DIAG_BC_S_LE:
705                 case INET_DIAG_BC_D_GE:
706                 case INET_DIAG_BC_D_LE:
707                         if (!valid_port_comparison(bc, len, &min_len))
708                                 return -EINVAL;
709                         break;
710                 case INET_DIAG_BC_AUTO:
711                 case INET_DIAG_BC_JMP:
712                 case INET_DIAG_BC_NOP:
713                         break;
714                 default:
715                         return -EINVAL;
716                 }
717
718                 if (op->code != INET_DIAG_BC_NOP) {
719                         if (op->no < min_len || op->no > len + 4 || op->no & 3)
720                                 return -EINVAL;
721                         if (op->no < len &&
722                             !valid_cc(bytecode, bytecode_len, len - op->no))
723                                 return -EINVAL;
724                 }
725
726                 if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
727                         return -EINVAL;
728                 bc  += op->yes;
729                 len -= op->yes;
730         }
731         return len == 0 ? 0 : -EINVAL;
732 }
733
734 static int inet_csk_diag_dump(struct sock *sk,
735                               struct sk_buff *skb,
736                               struct netlink_callback *cb,
737                               const struct inet_diag_req_v2 *r,
738                               const struct nlattr *bc)
739 {
740         if (!inet_diag_bc_sk(bc, sk))
741                 return 0;
742
743         return inet_csk_diag_fill(sk, skb, r,
744                                   sk_user_ns(NETLINK_CB(cb->skb).sk),
745                                   NETLINK_CB(cb->skb).portid,
746                                   cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
747 }
748
749 static void twsk_build_assert(void)
750 {
751         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
752                      offsetof(struct sock, sk_family));
753
754         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
755                      offsetof(struct inet_sock, inet_num));
756
757         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
758                      offsetof(struct inet_sock, inet_dport));
759
760         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
761                      offsetof(struct inet_sock, inet_rcv_saddr));
762
763         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
764                      offsetof(struct inet_sock, inet_daddr));
765
766 #if IS_ENABLED(CONFIG_IPV6)
767         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
768                      offsetof(struct sock, sk_v6_rcv_saddr));
769
770         BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
771                      offsetof(struct sock, sk_v6_daddr));
772 #endif
773 }
774
775 void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
776                          struct netlink_callback *cb,
777                          const struct inet_diag_req_v2 *r, struct nlattr *bc)
778 {
779         struct net *net = sock_net(skb->sk);
780         int i, num, s_i, s_num;
781         u32 idiag_states = r->idiag_states;
782
783         if (idiag_states & TCPF_SYN_RECV)
784                 idiag_states |= TCPF_NEW_SYN_RECV;
785         s_i = cb->args[1];
786         s_num = num = cb->args[2];
787
788         if (cb->args[0] == 0) {
789                 if (!(idiag_states & TCPF_LISTEN))
790                         goto skip_listen_ht;
791
792                 for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
793                         struct inet_listen_hashbucket *ilb;
794                         struct sock *sk;
795
796                         num = 0;
797                         ilb = &hashinfo->listening_hash[i];
798                         spin_lock_bh(&ilb->lock);
799                         sk_for_each(sk, &ilb->head) {
800                                 struct inet_sock *inet = inet_sk(sk);
801
802                                 if (!net_eq(sock_net(sk), net))
803                                         continue;
804
805                                 if (num < s_num) {
806                                         num++;
807                                         continue;
808                                 }
809
810                                 if (r->sdiag_family != AF_UNSPEC &&
811                                     sk->sk_family != r->sdiag_family)
812                                         goto next_listen;
813
814                                 if (r->id.idiag_sport != inet->inet_sport &&
815                                     r->id.idiag_sport)
816                                         goto next_listen;
817
818                                 if (r->id.idiag_dport ||
819                                     cb->args[3] > 0)
820                                         goto next_listen;
821
822                                 if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) {
823                                         spin_unlock_bh(&ilb->lock);
824                                         goto done;
825                                 }
826
827 next_listen:
828                                 cb->args[3] = 0;
829                                 cb->args[4] = 0;
830                                 ++num;
831                         }
832                         spin_unlock_bh(&ilb->lock);
833
834                         s_num = 0;
835                         cb->args[3] = 0;
836                         cb->args[4] = 0;
837                 }
838 skip_listen_ht:
839                 cb->args[0] = 1;
840                 s_i = num = s_num = 0;
841         }
842
843         if (!(idiag_states & ~TCPF_LISTEN))
844                 goto out;
845
846         for (i = s_i; i <= hashinfo->ehash_mask; i++) {
847                 struct inet_ehash_bucket *head = &hashinfo->ehash[i];
848                 spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
849                 struct hlist_nulls_node *node;
850                 struct sock *sk;
851
852                 num = 0;
853
854                 if (hlist_nulls_empty(&head->chain))
855                         continue;
856
857                 if (i > s_i)
858                         s_num = 0;
859
860                 spin_lock_bh(lock);
861                 sk_nulls_for_each(sk, node, &head->chain) {
862                         int state, res;
863
864                         if (!net_eq(sock_net(sk), net))
865                                 continue;
866                         if (num < s_num)
867                                 goto next_normal;
868                         state = (sk->sk_state == TCP_TIME_WAIT) ?
869                                 inet_twsk(sk)->tw_substate : sk->sk_state;
870                         if (!(idiag_states & (1 << state)))
871                                 goto next_normal;
872                         if (r->sdiag_family != AF_UNSPEC &&
873                             sk->sk_family != r->sdiag_family)
874                                 goto next_normal;
875                         if (r->id.idiag_sport != htons(sk->sk_num) &&
876                             r->id.idiag_sport)
877                                 goto next_normal;
878                         if (r->id.idiag_dport != sk->sk_dport &&
879                             r->id.idiag_dport)
880                                 goto next_normal;
881                         twsk_build_assert();
882
883                         if (!inet_diag_bc_sk(bc, sk))
884                                 goto next_normal;
885
886                         res = sk_diag_fill(sk, skb, r,
887                                            sk_user_ns(NETLINK_CB(cb->skb).sk),
888                                            NETLINK_CB(cb->skb).portid,
889                                            cb->nlh->nlmsg_seq, NLM_F_MULTI,
890                                            cb->nlh);
891                         if (res < 0) {
892                                 spin_unlock_bh(lock);
893                                 goto done;
894                         }
895 next_normal:
896                         ++num;
897                 }
898
899                 spin_unlock_bh(lock);
900                 cond_resched();
901         }
902
903 done:
904         cb->args[1] = i;
905         cb->args[2] = num;
906 out:
907         ;
908 }
909 EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
910
911 static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
912                             const struct inet_diag_req_v2 *r,
913                             struct nlattr *bc)
914 {
915         const struct inet_diag_handler *handler;
916         int err = 0;
917
918         handler = inet_diag_lock_handler(r->sdiag_protocol);
919         if (!IS_ERR(handler))
920                 handler->dump(skb, cb, r, bc);
921         else
922                 err = PTR_ERR(handler);
923         inet_diag_unlock_handler(handler);
924
925         return err ? : skb->len;
926 }
927
928 static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
929 {
930         int hdrlen = sizeof(struct inet_diag_req_v2);
931         struct nlattr *bc = NULL;
932
933         if (nlmsg_attrlen(cb->nlh, hdrlen))
934                 bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
935
936         return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc);
937 }
938
/*
 * Translate a legacy inet_diag netlink message type into the IP protocol
 * number whose handler serves it. Unknown types map to 0.
 */
static int inet_diag_type2proto(int type)
{
	if (type == TCPDIAG_GETSOCK)
		return IPPROTO_TCP;
	if (type == DCCPDIAG_GETSOCK)
		return IPPROTO_DCCP;
	return 0;
}
950
951 static int inet_diag_dump_compat(struct sk_buff *skb,
952                                  struct netlink_callback *cb)
953 {
954         struct inet_diag_req *rc = nlmsg_data(cb->nlh);
955         int hdrlen = sizeof(struct inet_diag_req);
956         struct inet_diag_req_v2 req;
957         struct nlattr *bc = NULL;
958
959         req.sdiag_family = AF_UNSPEC; /* compatibility */
960         req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
961         req.idiag_ext = rc->idiag_ext;
962         req.idiag_states = rc->idiag_states;
963         req.id = rc->id;
964
965         if (nlmsg_attrlen(cb->nlh, hdrlen))
966                 bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
967
968         return __inet_diag_dump(skb, cb, &req, bc);
969 }
970
971 static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
972                                       const struct nlmsghdr *nlh)
973 {
974         struct inet_diag_req *rc = nlmsg_data(nlh);
975         struct inet_diag_req_v2 req;
976
977         req.sdiag_family = rc->idiag_family;
978         req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
979         req.idiag_ext = rc->idiag_ext;
980         req.idiag_states = rc->idiag_states;
981         req.id = rc->id;
982
983         return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, &req);
984 }
985
986 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
987 {
988         int hdrlen = sizeof(struct inet_diag_req);
989         struct net *net = sock_net(skb->sk);
990
991         if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
992             nlmsg_len(nlh) < hdrlen)
993                 return -EINVAL;
994
995         if (nlh->nlmsg_flags & NLM_F_DUMP) {
996                 if (nlmsg_attrlen(nlh, hdrlen)) {
997                         struct nlattr *attr;
998
999                         attr = nlmsg_find_attr(nlh, hdrlen,
1000                                                INET_DIAG_REQ_BYTECODE);
1001                         if (!attr ||
1002                             nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
1003                             inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
1004                                 return -EINVAL;
1005                 }
1006                 {
1007                         struct netlink_dump_control c = {
1008                                 .dump = inet_diag_dump_compat,
1009                         };
1010                         return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1011                 }
1012         }
1013
1014         return inet_diag_get_exact_compat(skb, nlh);
1015 }
1016
1017 static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
1018 {
1019         int hdrlen = sizeof(struct inet_diag_req_v2);
1020         struct net *net = sock_net(skb->sk);
1021
1022         if (nlmsg_len(h) < hdrlen)
1023                 return -EINVAL;
1024
1025         if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
1026             h->nlmsg_flags & NLM_F_DUMP) {
1027                 if (nlmsg_attrlen(h, hdrlen)) {
1028                         struct nlattr *attr;
1029
1030                         attr = nlmsg_find_attr(h, hdrlen,
1031                                                INET_DIAG_REQ_BYTECODE);
1032                         if (!attr ||
1033                             nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
1034                             inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
1035                                 return -EINVAL;
1036                 }
1037                 {
1038                         struct netlink_dump_control c = {
1039                                 .dump = inet_diag_dump,
1040                         };
1041                         return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1042                 }
1043         }
1044
1045         return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
1046 }
1047
1048 static
1049 int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
1050 {
1051         const struct inet_diag_handler *handler;
1052         struct nlmsghdr *nlh;
1053         struct nlattr *attr;
1054         struct inet_diag_msg *r;
1055         void *info = NULL;
1056         int err = 0;
1057
1058         nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
1059         if (!nlh)
1060                 return -ENOMEM;
1061
1062         r = nlmsg_data(nlh);
1063         memset(r, 0, sizeof(*r));
1064         inet_diag_msg_common_fill(r, sk);
1065         if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
1066                 r->id.idiag_sport = inet_sk(sk)->inet_sport;
1067         r->idiag_state = sk->sk_state;
1068
1069         if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
1070                 nlmsg_cancel(skb, nlh);
1071                 return err;
1072         }
1073
1074         handler = inet_diag_lock_handler(sk->sk_protocol);
1075         if (IS_ERR(handler)) {
1076                 inet_diag_unlock_handler(handler);
1077                 nlmsg_cancel(skb, nlh);
1078                 return PTR_ERR(handler);
1079         }
1080
1081         attr = handler->idiag_info_size
1082                 ? nla_reserve_64bit(skb, INET_DIAG_INFO,
1083                                     handler->idiag_info_size,
1084                                     INET_DIAG_PAD)
1085                 : NULL;
1086         if (attr)
1087                 info = nla_data(attr);
1088
1089         handler->idiag_get_info(sk, r, info);
1090         inet_diag_unlock_handler(handler);
1091
1092         nlmsg_end(skb, nlh);
1093         return 0;
1094 }
1095
1096 static const struct sock_diag_handler inet_diag_handler = {
1097         .family = AF_INET,
1098         .dump = inet_diag_handler_cmd,
1099         .get_info = inet_diag_handler_get_info,
1100         .destroy = inet_diag_handler_cmd,
1101 };
1102
1103 static const struct sock_diag_handler inet6_diag_handler = {
1104         .family = AF_INET6,
1105         .dump = inet_diag_handler_cmd,
1106         .get_info = inet_diag_handler_get_info,
1107         .destroy = inet_diag_handler_cmd,
1108 };
1109
1110 int inet_diag_register(const struct inet_diag_handler *h)
1111 {
1112         const __u16 type = h->idiag_type;
1113         int err = -EINVAL;
1114
1115         if (type >= IPPROTO_MAX)
1116                 goto out;
1117
1118         mutex_lock(&inet_diag_table_mutex);
1119         err = -EEXIST;
1120         if (!inet_diag_table[type]) {
1121                 inet_diag_table[type] = h;
1122                 err = 0;
1123         }
1124         mutex_unlock(&inet_diag_table_mutex);
1125 out:
1126         return err;
1127 }
1128 EXPORT_SYMBOL_GPL(inet_diag_register);
1129
1130 void inet_diag_unregister(const struct inet_diag_handler *h)
1131 {
1132         const __u16 type = h->idiag_type;
1133
1134         if (type >= IPPROTO_MAX)
1135                 return;
1136
1137         mutex_lock(&inet_diag_table_mutex);
1138         inet_diag_table[type] = NULL;
1139         mutex_unlock(&inet_diag_table_mutex);
1140 }
1141 EXPORT_SYMBOL_GPL(inet_diag_unregister);
1142
1143 static int __init inet_diag_init(void)
1144 {
1145         const int inet_diag_table_size = (IPPROTO_MAX *
1146                                           sizeof(struct inet_diag_handler *));
1147         int err = -ENOMEM;
1148
1149         inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1150         if (!inet_diag_table)
1151                 goto out;
1152
1153         err = sock_diag_register(&inet_diag_handler);
1154         if (err)
1155                 goto out_free_nl;
1156
1157         err = sock_diag_register(&inet6_diag_handler);
1158         if (err)
1159                 goto out_free_inet;
1160
1161         sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1162 out:
1163         return err;
1164
1165 out_free_inet:
1166         sock_diag_unregister(&inet_diag_handler);
1167 out_free_nl:
1168         kfree(inet_diag_table);
1169         goto out;
1170 }
1171
1172 static void __exit inet_diag_exit(void)
1173 {
1174         sock_diag_unregister(&inet6_diag_handler);
1175         sock_diag_unregister(&inet_diag_handler);
1176         sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
1177         kfree(inet_diag_table);
1178 }
1179
/* Module entry and exit points. */
module_init(inet_diag_init);
module_exit(inet_diag_exit);
MODULE_LICENSE("GPL");
/*
 * Module aliases for NETLINK_SOCK_DIAG requests of family AF_INET (2) and
 * AF_INET6 (10), the two families registered above. The numeric literals
 * presumably allow sock_diag to autoload this module by family; confirm
 * against sock_diag's request_module() format string.
 */
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1184 MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);