SUNRPC: Mark auth and cred operation tables as constant.
[cascardo/linux.git] / net / sunrpc / auth.c
index 55163af..d3f0f94 100644 (file)
@@ -18,7 +18,8 @@
 # define RPCDBG_FACILITY       RPCDBG_AUTH
 #endif
 
-static struct rpc_authops *    auth_flavors[RPC_AUTH_MAXFLAVOR] = {
+static DEFINE_SPINLOCK(rpc_authflavor_lock);
+static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
        &authnull_ops,          /* AUTH_NULL */
        &authunix_ops,          /* AUTH_UNIX */
        NULL,                   /* others can be loadable modules */
@@ -32,55 +33,67 @@ pseudoflavor_to_flavor(u32 flavor) {
 }
 
 int
-rpcauth_register(struct rpc_authops *ops)
+rpcauth_register(const struct rpc_authops *ops)
 {
        rpc_authflavor_t flavor;
+       int ret = -EPERM;
 
        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
                return -EINVAL;
-       if (auth_flavors[flavor] != NULL)
-               return -EPERM;          /* what else? */
-       auth_flavors[flavor] = ops;
-       return 0;
+       spin_lock(&rpc_authflavor_lock);
+       if (auth_flavors[flavor] == NULL) {
+               auth_flavors[flavor] = ops;
+               ret = 0;
+       }
+       spin_unlock(&rpc_authflavor_lock);
+       return ret;
 }
 
 int
-rpcauth_unregister(struct rpc_authops *ops)
+rpcauth_unregister(const struct rpc_authops *ops)
 {
        rpc_authflavor_t flavor;
+       int ret = -EPERM;
 
        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
                return -EINVAL;
-       if (auth_flavors[flavor] != ops)
-               return -EPERM;          /* what else? */
-       auth_flavors[flavor] = NULL;
-       return 0;
+       spin_lock(&rpc_authflavor_lock);
+       if (auth_flavors[flavor] == ops) {
+               auth_flavors[flavor] = NULL;
+               ret = 0;
+       }
+       spin_unlock(&rpc_authflavor_lock);
+       return ret;
 }
 
 struct rpc_auth *
 rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt)
 {
        struct rpc_auth         *auth;
-       struct rpc_authops      *ops;
+       const struct rpc_authops *ops;
        u32                     flavor = pseudoflavor_to_flavor(pseudoflavor);
 
        auth = ERR_PTR(-EINVAL);
        if (flavor >= RPC_AUTH_MAXFLAVOR)
                goto out;
 
-       /* FIXME - auth_flavors[] really needs an rw lock,
-        * and module refcounting. */
 #ifdef CONFIG_KMOD
        if ((ops = auth_flavors[flavor]) == NULL)
                request_module("rpc-auth-%u", flavor);
 #endif
-       if ((ops = auth_flavors[flavor]) == NULL)
+       spin_lock(&rpc_authflavor_lock);
+       ops = auth_flavors[flavor];
+       if (ops == NULL || !try_module_get(ops->owner)) {
+               spin_unlock(&rpc_authflavor_lock);
                goto out;
+       }
+       spin_unlock(&rpc_authflavor_lock);
        auth = ops->create(clnt, pseudoflavor);
+       module_put(ops->owner);
        if (IS_ERR(auth))
                return auth;
        if (clnt->cl_auth)
-               rpcauth_destroy(clnt->cl_auth);
+               rpcauth_release(clnt->cl_auth);
        clnt->cl_auth = auth;
 
 out:
@@ -88,7 +101,7 @@ out:
 }
 
 void
-rpcauth_destroy(struct rpc_auth *auth)
+rpcauth_release(struct rpc_auth *auth)
 {
        if (!atomic_dec_and_test(&auth->au_count))
                return;
@@ -137,9 +150,8 @@ void rpcauth_destroy_credlist(struct hlist_head *head)
  * that are not referenced.
  */
 void
-rpcauth_free_credcache(struct rpc_auth *auth)
+rpcauth_clear_credcache(struct rpc_cred_cache *cache)
 {
-       struct rpc_cred_cache *cache = auth->au_credcache;
        HLIST_HEAD(free);
        struct hlist_node *pos, *next;
        struct rpc_cred *cred;
@@ -157,6 +169,21 @@ rpcauth_free_credcache(struct rpc_auth *auth)
        rpcauth_destroy_credlist(&free);
 }
 
+/*
+ * Destroy the RPC credential cache
+ */
+void
+rpcauth_destroy_credcache(struct rpc_auth *auth)
+{
+       struct rpc_cred_cache *cache = auth->au_credcache;
+
+       if (cache) {
+               auth->au_credcache = NULL;
+               rpcauth_clear_credcache(cache);
+               kfree(cache);
+       }
+}
+
 static void
 rpcauth_prune_expired(struct rpc_auth *auth, struct rpc_cred *cred, struct hlist_head *free)
 {
@@ -181,7 +208,7 @@ rpcauth_gc_credcache(struct rpc_auth *auth, struct hlist_head *free)
        struct rpc_cred *cred;
        int             i;
 
-       dprintk("RPC: gc'ing RPC credentials for auth %p\n", auth);
+       dprintk("RPC:       gc'ing RPC credentials for auth %p\n", auth);
        for (i = 0; i < RPC_CREDCACHE_NR; i++) {
                hlist_for_each_safe(pos, next, &cache->hashtable[i]) {
                        cred = hlist_entry(pos, struct rpc_cred, cr_hash);
@@ -213,7 +240,7 @@ retry:
                rpcauth_gc_credcache(auth, &free);
        hlist_for_each_safe(pos, next, &cache->hashtable[nr]) {
                struct rpc_cred *entry;
-               entry = hlist_entry(pos, struct rpc_cred, cr_hash);
+               entry = hlist_entry(pos, struct rpc_cred, cr_hash);
                if (entry->cr_ops->crmatch(acred, entry, flags)) {
                        hlist_del(&entry->cr_hash);
                        cred = entry;
@@ -267,7 +294,7 @@ rpcauth_lookupcred(struct rpc_auth *auth, int flags)
        };
        struct rpc_cred *ret;
 
-       dprintk("RPC:     looking up %s cred\n",
+       dprintk("RPC:       looking up %s cred\n",
                auth->au_ops->au_name);
        get_group_info(acred.group_info);
        ret = auth->au_ops->lookup_cred(auth, &acred, flags);
@@ -287,7 +314,7 @@ rpcauth_bindcred(struct rpc_task *task)
        struct rpc_cred *ret;
        int flags = 0;
 
-       dprintk("RPC: %4d looking up %s cred\n",
+       dprintk("RPC: %5u looking up %s cred\n",
                task->tk_pid, task->tk_auth->au_ops->au_name);
        get_group_info(acred.group_info);
        if (task->tk_flags & RPC_TASK_ROOTCREDS)
@@ -304,8 +331,9 @@ rpcauth_bindcred(struct rpc_task *task)
 void
 rpcauth_holdcred(struct rpc_task *task)
 {
-       dprintk("RPC: %4d holding %s cred %p\n",
-               task->tk_pid, task->tk_auth->au_ops->au_name, task->tk_msg.rpc_cred);
+       dprintk("RPC: %5u holding %s cred %p\n",
+               task->tk_pid, task->tk_auth->au_ops->au_name,
+               task->tk_msg.rpc_cred);
        if (task->tk_msg.rpc_cred)
                get_rpccred(task->tk_msg.rpc_cred);
 }
@@ -324,30 +352,30 @@ rpcauth_unbindcred(struct rpc_task *task)
 {
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
 
-       dprintk("RPC: %4d releasing %s cred %p\n",
+       dprintk("RPC: %5u releasing %s cred %p\n",
                task->tk_pid, task->tk_auth->au_ops->au_name, cred);
 
        put_rpccred(cred);
        task->tk_msg.rpc_cred = NULL;
 }
 
-u32 *
-rpcauth_marshcred(struct rpc_task *task, u32 *p)
+__be32 *
+rpcauth_marshcred(struct rpc_task *task, __be32 *p)
 {
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
 
-       dprintk("RPC: %4d marshaling %s cred %p\n",
+       dprintk("RPC: %5u marshaling %s cred %p\n",
                task->tk_pid, task->tk_auth->au_ops->au_name, cred);
 
        return cred->cr_ops->crmarshal(task, p);
 }
 
-u32 *
-rpcauth_checkverf(struct rpc_task *task, u32 *p)
+__be32 *
+rpcauth_checkverf(struct rpc_task *task, __be32 *p)
 {
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
 
-       dprintk("RPC: %4d validating %s cred %p\n",
+       dprintk("RPC: %5u validating %s cred %p\n",
                task->tk_pid, task->tk_auth->au_ops->au_name, cred);
 
        return cred->cr_ops->crvalidate(task, p);
@@ -355,11 +383,11 @@ rpcauth_checkverf(struct rpc_task *task, u32 *p)
 
 int
 rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp,
-               u32 *data, void *obj)
+               __be32 *data, void *obj)
 {
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
 
-       dprintk("RPC: %4d using %s cred %p to wrap rpc data\n",
+       dprintk("RPC: %5u using %s cred %p to wrap rpc data\n",
                        task->tk_pid, cred->cr_ops->cr_name, cred);
        if (cred->cr_ops->crwrap_req)
                return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
@@ -369,11 +397,11 @@ rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp,
 
 int
 rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp,
-               u32 *data, void *obj)
+               __be32 *data, void *obj)
 {
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
 
-       dprintk("RPC: %4d using %s cred %p to unwrap rpc data\n",
+       dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n",
                        task->tk_pid, cred->cr_ops->cr_name, cred);
        if (cred->cr_ops->crunwrap_resp)
                return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
@@ -388,7 +416,7 @@ rpcauth_refreshcred(struct rpc_task *task)
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
        int err;
 
-       dprintk("RPC: %4d refreshing %s cred %p\n",
+       dprintk("RPC: %5u refreshing %s cred %p\n",
                task->tk_pid, task->tk_auth->au_ops->au_name, cred);
 
        err = cred->cr_ops->crrefresh(task);
@@ -400,7 +428,7 @@ rpcauth_refreshcred(struct rpc_task *task)
 void
 rpcauth_invalcred(struct rpc_task *task)
 {
-       dprintk("RPC: %4d invalidating %s cred %p\n",
+       dprintk("RPC: %5u invalidating %s cred %p\n",
                task->tk_pid, task->tk_auth->au_ops->au_name, task->tk_msg.rpc_cred);
        spin_lock(&rpc_credcache_lock);
        if (task->tk_msg.rpc_cred)