inet: create IPv6-equivalent inet_hash function
[cascardo/linux.git] / net / ipv6 / tcp_ipv6.c
index c5429a6..d72bcfb 100644 (file)
@@ -61,7 +61,6 @@
 #include <net/timewait_sock.h>
 #include <net/inet_common.h>
 #include <net/secure_seq.h>
-#include <net/tcp_memcontrol.h>
 #include <net/busy_poll.h>
 
 #include <linux/proc_fs.h>
@@ -93,10 +92,9 @@ static void inet6_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
 {
        struct dst_entry *dst = skb_dst(skb);
 
-       if (dst) {
+       if (dst && dst_hold_safe(dst)) {
                const struct rt6_info *rt = (const struct rt6_info *)dst;
 
-               dst_hold(dst);
                sk->sk_rx_dst = dst;
                inet_sk(sk)->rx_dst_ifindex = skb->skb_iif;
                inet6_sk(sk)->rx_dst_cookie = rt6_get_cookie(rt);
@@ -120,6 +118,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
        struct ipv6_pinfo *np = inet6_sk(sk);
        struct tcp_sock *tp = tcp_sk(sk);
        struct in6_addr *saddr = NULL, *final_p, final;
+       struct ipv6_txoptions *opt;
        struct flowi6 fl6;
        struct dst_entry *dst;
        int addr_type;
@@ -235,7 +234,8 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
        fl6.fl6_dport = usin->sin6_port;
        fl6.fl6_sport = inet->inet_sport;
 
-       final_p = fl6_update_dst(&fl6, np->opt, &final);
+       opt = rcu_dereference_protected(np->opt, sock_owned_by_user(sk));
+       final_p = fl6_update_dst(&fl6, opt, &final);
 
        security_sk_classify_flow(sk, flowi6_to_flowi(&fl6));
 
@@ -255,7 +255,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
        inet->inet_rcv_saddr = LOOPBACK4_IPV6;
 
        sk->sk_gso_type = SKB_GSO_TCPV6;
-       __ip6_dst_store(sk, dst, NULL, NULL);
+       ip6_dst_store(sk, dst, NULL, NULL);
 
        if (tcp_death_row.sysctl_tw_recycle &&
            !tp->rx_opt.ts_recent_stamp &&
@@ -263,9 +263,9 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr,
                tcp_fetch_timewait_stamp(sk, dst);
 
        icsk->icsk_ext_hdr_len = 0;
-       if (np->opt)
-               icsk->icsk_ext_hdr_len = (np->opt->opt_flen +
-                                         np->opt->opt_nflen);
+       if (opt)
+               icsk->icsk_ext_hdr_len = opt->opt_flen +
+                                        opt->opt_nflen;
 
        tp->rx_opt.mss_clamp = IPV6_MIN_MTU - sizeof(struct tcphdr) - sizeof(struct ipv6hdr);
 
@@ -461,7 +461,10 @@ static int tcp_v6_send_synack(const struct sock *sk, struct dst_entry *dst,
                if (np->repflow && ireq->pktopts)
                        fl6->flowlabel = ip6_flowlabel(ipv6_hdr(ireq->pktopts));
 
-               err = ip6_xmit(sk, skb, fl6, np->opt, np->tclass);
+               rcu_read_lock();
+               err = ip6_xmit(sk, skb, fl6, rcu_dereference(np->opt),
+                              np->tclass);
+               rcu_read_unlock();
                err = net_xmit_eval(err);
        }
 
@@ -852,7 +855,9 @@ static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb)
 
 #ifdef CONFIG_TCP_MD5SIG
        hash_location = tcp_parse_md5sig_option(th);
-       if (!sk && hash_location) {
+       if (sk && sk_fullsock(sk)) {
+               key = tcp_v6_md5_do_lookup(sk, &ipv6h->saddr);
+       } else if (hash_location) {
                /*
                 * active side is lost. Try to find listening socket through
                 * source port, and then find md5 key through listening socket.
@@ -875,8 +880,6 @@ static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb)
                genhash = tcp_v6_md5_hash_skb(newhash, key, NULL, skb);
                if (genhash || memcmp(hash_location, newhash, 16) != 0)
                        goto release_sk1;
-       } else {
-               key = sk ? tcp_v6_md5_do_lookup(sk, &ipv6h->saddr) : NULL;
        }
 #endif
 
@@ -972,6 +975,7 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
        struct inet_request_sock *ireq;
        struct ipv6_pinfo *newnp;
        const struct ipv6_pinfo *np = inet6_sk(sk);
+       struct ipv6_txoptions *opt;
        struct tcp6_sock *newtcp6sk;
        struct inet_sock *newinet;
        struct tcp_sock *newtp;
@@ -1056,7 +1060,7 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
         */
 
        newsk->sk_gso_type = SKB_GSO_TCPV6;
-       __ip6_dst_store(newsk, dst, NULL, NULL);
+       ip6_dst_store(newsk, dst, NULL, NULL);
        inet6_sk_rx_dst_set(newsk, skb);
 
        newtcp6sk = (struct tcp6_sock *)newsk;
@@ -1098,13 +1102,15 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
           but we make one more one thing there: reattach optmem
           to newsk.
         */
-       if (np->opt)
-               newnp->opt = ipv6_dup_options(newsk, np->opt);
-
+       opt = rcu_dereference(np->opt);
+       if (opt) {
+               opt = ipv6_dup_options(newsk, opt);
+               RCU_INIT_POINTER(newnp->opt, opt);
+       }
        inet_csk(newsk)->icsk_ext_hdr_len = 0;
-       if (newnp->opt)
-               inet_csk(newsk)->icsk_ext_hdr_len = (newnp->opt->opt_nflen +
-                                                    newnp->opt->opt_flen);
+       if (opt)
+               inet_csk(newsk)->icsk_ext_hdr_len = opt->opt_nflen +
+                                                   opt->opt_flen;
 
        tcp_ca_openreq_child(newsk, dst);
 
@@ -1130,7 +1136,7 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
                 */
                tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
                               AF_INET6, key->key, key->keylen,
-                              sk_gfp_atomic(sk, GFP_ATOMIC));
+                              sk_gfp_mask(sk, GFP_ATOMIC));
        }
 #endif
 
@@ -1146,7 +1152,7 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
                /* Clone pktoptions received with SYN, if we own the req */
                if (ireq->pktopts) {
                        newnp->pktoptions = skb_clone(ireq->pktopts,
-                                                     sk_gfp_atomic(sk, GFP_ATOMIC));
+                                                     sk_gfp_mask(sk, GFP_ATOMIC));
                        consume_skb(ireq->pktopts);
                        ireq->pktopts = NULL;
                        if (newnp->pktoptions)
@@ -1212,7 +1218,7 @@ static int tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb)
                                               --ANK (980728)
         */
        if (np->rxopt.all)
-               opt_skb = skb_clone(skb, sk_gfp_atomic(sk, GFP_ATOMIC));
+               opt_skb = skb_clone(skb, sk_gfp_mask(sk, GFP_ATOMIC));
 
        if (sk->sk_state == TCP_ESTABLISHED) { /* Fast path */
                struct dst_entry *dst = sk->sk_rx_dst;
@@ -1511,7 +1517,9 @@ do_time_wait:
                break;
        case TCP_TW_RST:
                tcp_v6_restore_cb(skb);
-               goto no_tcp_socket;
+               tcp_v6_send_reset(sk, skb);
+               inet_twsk_deschedule_put(inet_twsk(sk));
+               goto discard_it;
        case TCP_TW_SUCCESS:
                ;
        }
@@ -1857,7 +1865,7 @@ struct proto tcpv6_prot = {
        .sendpage               = tcp_sendpage,
        .backlog_rcv            = tcp_v6_do_rcv,
        .release_cb             = tcp_release_cb,
-       .hash                   = inet_hash,
+       .hash                   = inet6_hash,
        .unhash                 = inet_unhash,
        .get_port               = inet_csk_get_port,
        .enter_memory_pressure  = tcp_enter_memory_pressure,
@@ -1879,11 +1887,9 @@ struct proto tcpv6_prot = {
 #ifdef CONFIG_COMPAT
        .compat_setsockopt      = compat_tcp_setsockopt,
        .compat_getsockopt      = compat_tcp_getsockopt,
-#endif
-#ifdef CONFIG_MEMCG_KMEM
-       .proto_cgroup           = tcp_proto_cgroup,
 #endif
        .clear_sk               = tcp_v6_clear_sk,
+       .diag_destroy           = tcp_abort,
 };
 
 static const struct inet6_protocol tcpv6_protocol = {