inet: refactor inet[6]_lookup functions to take skb
[cascardo/linux.git] / net / ipv6 / inet6_hashtables.c
index 21ace5a..004345d 100644 (file)
@@ -121,7 +121,9 @@ static inline int compute_score(struct sock *sk, struct net *net,
 }
 
 struct sock *inet6_lookup_listener(struct net *net,
-               struct inet_hashinfo *hashinfo, const struct in6_addr *saddr,
+               struct inet_hashinfo *hashinfo,
+               struct sk_buff *skb, int doff,
+               const struct in6_addr *saddr,
                const __be16 sport, const struct in6_addr *daddr,
                const unsigned short hnum, const int dif)
 {
@@ -177,6 +179,7 @@ begin:
 EXPORT_SYMBOL_GPL(inet6_lookup_listener);
 
 struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
+                         struct sk_buff *skb, int doff,
                          const struct in6_addr *saddr, const __be16 sport,
                          const struct in6_addr *daddr, const __be16 dport,
                          const int dif)
@@ -184,7 +187,8 @@ struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
        struct sock *sk;
 
        local_bh_disable();
-       sk = __inet6_lookup(net, hashinfo, saddr, sport, daddr, ntohs(dport), dif);
+       sk = __inet6_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
+                           ntohs(dport), dif);
        local_bh_enable();
 
        return sk;
@@ -274,3 +278,59 @@ int inet6_hash_connect(struct inet_timewait_death_row *death_row,
                                   __inet6_check_established);
 }
 EXPORT_SYMBOL_GPL(inet6_hash_connect);
+
+int inet6_hash(struct sock *sk)
+{
+       if (sk->sk_state != TCP_CLOSE) {
+               local_bh_disable();
+               __inet_hash(sk, NULL);
+               local_bh_enable();
+       }
+
+       return 0;
+}
+EXPORT_SYMBOL_GPL(inet6_hash);
+
+/* match_wildcard == true:  IPV6_ADDR_ANY equals to any IPv6 addresses if IPv6
+ *                          only, and any IPv4 addresses if not IPv6 only
+ * match_wildcard == false: addresses must be exactly the same, i.e.
+ *                          IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY,
+ *                          and 0.0.0.0 equals to 0.0.0.0 only
+ */
+int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
+                        bool match_wildcard)
+{
+       const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2);
+       int sk2_ipv6only = inet_v6_ipv6only(sk2);
+       int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr);
+       int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED;
+
+       /* if both are mapped, treat as IPv4 */
+       if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) {
+               if (!sk2_ipv6only) {
+                       if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr)
+                               return 1;
+                       if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr)
+                               return match_wildcard;
+               }
+               return 0;
+       }
+
+       if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY)
+               return 1;
+
+       if (addr_type2 == IPV6_ADDR_ANY && match_wildcard &&
+           !(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED))
+               return 1;
+
+       if (addr_type == IPV6_ADDR_ANY && match_wildcard &&
+           !(ipv6_only_sock(sk) && addr_type2 == IPV6_ADDR_MAPPED))
+               return 1;
+
+       if (sk2_rcv_saddr6 &&
+           ipv6_addr_equal(&sk->sk_v6_rcv_saddr, sk2_rcv_saddr6))
+               return 1;
+
+       return 0;
+}
+EXPORT_SYMBOL_GPL(ipv6_rcv_saddr_equal);