xfrm: SP lookups with mark
[cascardo/linux.git] / net / xfrm / xfrm_policy.c
index 4725a54..2a6e646 100644 (file)
@@ -469,16 +469,16 @@ static inline int xfrm_byidx_should_resize(struct net *net, int total)
        return 0;
 }
 
-void xfrm_spd_getinfo(struct xfrmk_spdinfo *si)
+void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si) /* copy net's policy counters and hash sizes into *si */
 {
        read_lock_bh(&xfrm_policy_lock);
-       si->incnt = init_net.xfrm.policy_count[XFRM_POLICY_IN];
-       si->outcnt = init_net.xfrm.policy_count[XFRM_POLICY_OUT];
-       si->fwdcnt = init_net.xfrm.policy_count[XFRM_POLICY_FWD];
-       si->inscnt = init_net.xfrm.policy_count[XFRM_POLICY_IN+XFRM_POLICY_MAX];
-       si->outscnt = init_net.xfrm.policy_count[XFRM_POLICY_OUT+XFRM_POLICY_MAX];
-       si->fwdscnt = init_net.xfrm.policy_count[XFRM_POLICY_FWD+XFRM_POLICY_MAX];
-       si->spdhcnt = init_net.xfrm.policy_idx_hmask;
+       si->incnt = net->xfrm.policy_count[XFRM_POLICY_IN]; /* reads the net passed in, no longer init_net */
+       si->outcnt = net->xfrm.policy_count[XFRM_POLICY_OUT];
+       si->fwdcnt = net->xfrm.policy_count[XFRM_POLICY_FWD];
+       si->inscnt = net->xfrm.policy_count[XFRM_POLICY_IN+XFRM_POLICY_MAX]; /* slots at dir + XFRM_POLICY_MAX; presumably per-socket policies - confirm */
+       si->outscnt = net->xfrm.policy_count[XFRM_POLICY_OUT+XFRM_POLICY_MAX];
+       si->fwdscnt = net->xfrm.policy_count[XFRM_POLICY_FWD+XFRM_POLICY_MAX];
+       si->spdhcnt = net->xfrm.policy_idx_hmask; /* the hash mask value is stored as-is */
        si->spdhmcnt = xfrm_policy_hashmax;
        read_unlock_bh(&xfrm_policy_lock);
 }
@@ -556,6 +556,7 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl)
        struct hlist_head *chain;
        struct hlist_node *entry, *newpos;
        struct dst_entry *gc_list;
+       u32 mark = policy->mark.v & policy->mark.m;
 
        write_lock_bh(&xfrm_policy_lock);
        chain = policy_hash_bysel(net, &policy->selector, policy->family, dir);
@@ -564,6 +565,7 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl)
        hlist_for_each_entry(pol, entry, chain, bydst) {
                if (pol->type == policy->type &&
                    !selector_cmp(&pol->selector, &policy->selector) &&
+                   (mark & pol->mark.m) == pol->mark.v &&
                    xfrm_sec_ctx_match(pol->security, policy->security) &&
                    !WARN_ON(delpol)) {
                        if (excl) {
@@ -635,8 +637,8 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl)
 }
 EXPORT_SYMBOL(xfrm_policy_insert);
 
-struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u8 type, int dir,
-                                         struct xfrm_selector *sel,
+struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u32 mark, u8 type,
+                                         int dir, struct xfrm_selector *sel,
                                          struct xfrm_sec_ctx *ctx, int delete,
                                          int *err)
 {
@@ -650,6 +652,7 @@ struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u8 type, int dir,
        ret = NULL;
        hlist_for_each_entry(pol, entry, chain, bydst) {
                if (pol->type == type &&
+                   (mark & pol->mark.m) == pol->mark.v &&
                    !selector_cmp(sel, &pol->selector) &&
                    xfrm_sec_ctx_match(ctx, pol->security)) {
                        xfrm_pol_hold(pol);
@@ -676,8 +679,8 @@ struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u8 type, int dir,
 }
 EXPORT_SYMBOL(xfrm_policy_bysel_ctx);
 
-struct xfrm_policy *xfrm_policy_byid(struct net *net, u8 type, int dir, u32 id,
-                                    int delete, int *err)
+struct xfrm_policy *xfrm_policy_byid(struct net *net, u32 mark, u8 type,
+                                    int dir, u32 id, int delete, int *err)
 {
        struct xfrm_policy *pol, *ret;
        struct hlist_head *chain;
@@ -692,7 +695,8 @@ struct xfrm_policy *xfrm_policy_byid(struct net *net, u8 type, int dir, u32 id,
        chain = net->xfrm.policy_byidx + idx_hash(net, id);
        ret = NULL;
        hlist_for_each_entry(pol, entry, chain, byidx) {
-               if (pol->type == type && pol->index == id) {
+               if (pol->type == type && pol->index == id &&
+                   (mark & pol->mark.m) == pol->mark.v) {
                        xfrm_pol_hold(pol);
                        if (delete) {
                                *err = security_xfrm_policy_delete(
@@ -771,7 +775,8 @@ xfrm_policy_flush_secctx_check(struct net *net, u8 type, struct xfrm_audit *audi
 
 int xfrm_policy_flush(struct net *net, u8 type, struct xfrm_audit *audit_info)
 {
-       int dir, err = 0;
+       int dir, err = 0, cnt = 0;
+       struct xfrm_policy *dp;
 
        write_lock_bh(&xfrm_policy_lock);
 
@@ -789,8 +794,10 @@ int xfrm_policy_flush(struct net *net, u8 type, struct xfrm_audit *audit_info)
                                     &net->xfrm.policy_inexact[dir], bydst) {
                        if (pol->type != type)
                                continue;
-                       __xfrm_policy_unlink(pol, dir);
+                       dp = __xfrm_policy_unlink(pol, dir);
                        write_unlock_bh(&xfrm_policy_lock);
+                       if (dp)
+                               cnt++;
 
                        xfrm_audit_policy_delete(pol, 1, audit_info->loginuid,
                                                 audit_info->sessionid,
@@ -809,8 +816,10 @@ int xfrm_policy_flush(struct net *net, u8 type, struct xfrm_audit *audit_info)
                                             bydst) {
                                if (pol->type != type)
                                        continue;
-                               __xfrm_policy_unlink(pol, dir);
+                               dp = __xfrm_policy_unlink(pol, dir);
                                write_unlock_bh(&xfrm_policy_lock);
+                               if (dp)
+                                       cnt++;
 
                                xfrm_audit_policy_delete(pol, 1,
                                                         audit_info->loginuid,
@@ -824,6 +833,8 @@ int xfrm_policy_flush(struct net *net, u8 type, struct xfrm_audit *audit_info)
                }
 
        }
+       if (!cnt)
+               err = -ESRCH;
        atomic_inc(&flow_cache_genid);
 out:
        write_unlock_bh(&xfrm_policy_lock);
@@ -909,6 +920,7 @@ static int xfrm_policy_match(struct xfrm_policy *pol, struct flowi *fl,
        int match, ret = -ESRCH;
 
        if (pol->family != family ||
+           (fl->mark & pol->mark.m) != pol->mark.v ||
            pol->type != type)
                return ret;
 
@@ -1033,6 +1045,10 @@ static struct xfrm_policy *xfrm_sk_policy_lookup(struct sock *sk, int dir, struc
                int err = 0;
 
                if (match) {
+                       if ((sk->sk_mark & pol->mark.m) != pol->mark.v) {
+                               pol = NULL;
+                               goto out;
+                       }
                        err = security_xfrm_policy_lookup(pol->security,
                                                      fl->secid,
                                                      policy_to_flow_dir(dir));
@@ -1045,6 +1061,7 @@ static struct xfrm_policy *xfrm_sk_policy_lookup(struct sock *sk, int dir, struc
                } else
                        pol = NULL;
        }
+out:
        read_unlock_bh(&xfrm_policy_lock);
        return pol;
 }
@@ -1309,15 +1326,28 @@ static inline int xfrm_get_tos(struct flowi *fl, int family)
        return tos;
 }
 
-static inline struct xfrm_dst *xfrm_alloc_dst(int family)
+static inline struct xfrm_dst *xfrm_alloc_dst(struct net *net, int family)
 {
        struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family);
+       struct dst_ops *dst_ops;
        struct xfrm_dst *xdst;
 
        if (!afinfo)
                return ERR_PTR(-EINVAL);
 
-       xdst = dst_alloc(afinfo->dst_ops) ?: ERR_PTR(-ENOBUFS);
+       switch (family) {
+       case AF_INET:
+               dst_ops = &net->xfrm.xfrm4_dst_ops;
+               break;
+#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
+       case AF_INET6:
+               dst_ops = &net->xfrm.xfrm6_dst_ops;
+               break;
+#endif
+       default:
+               BUG();
+       }
+       xdst = dst_alloc(dst_ops) ?: ERR_PTR(-ENOBUFS);
 
        xfrm_policy_put_afinfo(afinfo);
 
@@ -1366,6 +1396,7 @@ static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy,
                                            struct flowi *fl,
                                            struct dst_entry *dst)
 {
+       struct net *net = xp_net(policy);
        unsigned long now = jiffies;
        struct net_device *dev;
        struct dst_entry *dst_prev = NULL;
@@ -1389,7 +1420,7 @@ static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy,
        dst_hold(dst);
 
        for (; i < nx; i++) {
-               struct xfrm_dst *xdst = xfrm_alloc_dst(family);
+               struct xfrm_dst *xdst = xfrm_alloc_dst(net, family);
                struct dst_entry *dst1 = &xdst->u.dst;
 
                err = PTR_ERR(xdst);
@@ -2031,8 +2062,7 @@ int __xfrm_route_forward(struct sk_buff *skb, unsigned short family)
        int res;
 
        if (xfrm_decode_session(skb, &fl, family) < 0) {
-               /* XXX: we should have something like FWDHDRERROR here. */
-               XFRM_INC_STATS(net, LINUX_MIB_XFRMINHDRERROR);
+               XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR);
                return 0;
        }
 
@@ -2279,6 +2309,7 @@ EXPORT_SYMBOL(xfrm_bundle_ok);
 
 int xfrm_policy_register_afinfo(struct xfrm_policy_afinfo *afinfo)
 {
+       struct net *net;
        int err = 0;
        if (unlikely(afinfo == NULL))
                return -EINVAL;
@@ -2302,6 +2333,27 @@ int xfrm_policy_register_afinfo(struct xfrm_policy_afinfo *afinfo)
                xfrm_policy_afinfo[afinfo->family] = afinfo;
        }
        write_unlock_bh(&xfrm_policy_afinfo_lock);
+
+       rtnl_lock();
+       for_each_net(net) {
+               struct dst_ops *xfrm_dst_ops;
+
+               switch (afinfo->family) {
+               case AF_INET:
+                       xfrm_dst_ops = &net->xfrm.xfrm4_dst_ops;
+                       break;
+#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
+               case AF_INET6:
+                       xfrm_dst_ops = &net->xfrm.xfrm6_dst_ops;
+                       break;
+#endif
+               default:
+                       BUG();
+               }
+               *xfrm_dst_ops = *afinfo->dst_ops;
+       }
+       rtnl_unlock();
+
        return err;
 }
 EXPORT_SYMBOL(xfrm_policy_register_afinfo);
@@ -2332,6 +2384,22 @@ int xfrm_policy_unregister_afinfo(struct xfrm_policy_afinfo *afinfo)
 }
 EXPORT_SYMBOL(xfrm_policy_unregister_afinfo);
 
+static void __net_init xfrm_dst_ops_init(struct net *net) /* seed net's xfrm4/xfrm6 dst_ops from the registered afinfo entries */
+{
+       struct xfrm_policy_afinfo *afinfo;
+
+       read_lock_bh(&xfrm_policy_afinfo_lock); /* held across both table reads and both struct copies */
+       afinfo = xfrm_policy_afinfo[AF_INET];
+       if (afinfo)
+               net->xfrm.xfrm4_dst_ops = *afinfo->dst_ops; /* struct copy by value; skipped when the AF_INET slot is NULL */
+#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
+       afinfo = xfrm_policy_afinfo[AF_INET6];
+       if (afinfo)
+               net->xfrm.xfrm6_dst_ops = *afinfo->dst_ops; /* same copy for IPv6, compiled only when IPv6 is built in or modular */
+#endif
+       read_unlock_bh(&xfrm_policy_afinfo_lock);
+}
+
 static struct xfrm_policy_afinfo *xfrm_policy_get_afinfo(unsigned short family)
 {
        struct xfrm_policy_afinfo *afinfo;
@@ -2369,19 +2437,19 @@ static int __net_init xfrm_statistics_init(struct net *net)
 {
        int rv;
 
-       if (snmp_mib_init((void **)net->mib.xfrm_statistics,
+       if (snmp_mib_init((void __percpu **)net->mib.xfrm_statistics,
                          sizeof(struct linux_xfrm_mib)) < 0)
                return -ENOMEM;
        rv = xfrm_proc_init(net);
        if (rv < 0)
-               snmp_mib_free((void **)net->mib.xfrm_statistics);
+               snmp_mib_free((void __percpu **)net->mib.xfrm_statistics);
        return rv;
 }
 
 static void xfrm_statistics_fini(struct net *net)
 {
        xfrm_proc_fini(net);
-       snmp_mib_free((void **)net->mib.xfrm_statistics);
+       snmp_mib_free((void __percpu **)net->mib.xfrm_statistics); /* only the cast changes: it now carries the __percpu annotation */
 }
 #else
 static int __net_init xfrm_statistics_init(struct net *net)
@@ -2494,6 +2562,7 @@ static int __net_init xfrm_net_init(struct net *net)
        rv = xfrm_policy_init(net);
        if (rv < 0)
                goto out_policy;
+       xfrm_dst_ops_init(net);
        rv = xfrm_sysctl_init(net);
        if (rv < 0)
                goto out_sysctl;