net: tcp: assign tcp cong_ops when tcp sk is created
[cascardo/linux.git] / net / ipv4 / tcp_cong.c
index 7b09d8b..a6c8a57 100644 (file)
@@ -74,24 +74,34 @@ void tcp_unregister_congestion_control(struct tcp_congestion_ops *ca)
 EXPORT_SYMBOL_GPL(tcp_unregister_congestion_control);
 
 /* Assign choice of congestion control. */
-void tcp_init_congestion_control(struct sock *sk)
+void tcp_assign_congestion_control(struct sock *sk)
 {
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct tcp_congestion_ops *ca;
 
-       /* if no choice made yet assign the current value set as default */
-       if (icsk->icsk_ca_ops == &tcp_init_congestion_ops) {
-               rcu_read_lock();
-               list_for_each_entry_rcu(ca, &tcp_cong_list, list) {
-                       if (try_module_get(ca->owner)) {
-                               icsk->icsk_ca_ops = ca;
-                               break;
-                       }
-
-                       /* fallback to next available */
+       rcu_read_lock();
+       list_for_each_entry_rcu(ca, &tcp_cong_list, list) {
+               if (likely(try_module_get(ca->owner))) {
+                       icsk->icsk_ca_ops = ca;
+                       goto out;
                }
-               rcu_read_unlock();
+               /* Fallback to next available. The last really
+                * guaranteed fallback is Reno from this list.
+                */
        }
+out:
+       rcu_read_unlock();
+
+       /* Clear out private data before diag gets it and
+        * the ca has not been initialized.
+        */
+       if (ca->get_info)
+               memset(icsk->icsk_ca_priv, 0, sizeof(icsk->icsk_ca_priv));
+}
+
+void tcp_init_congestion_control(struct sock *sk)
+{
+       const struct inet_connection_sock *icsk = inet_csk(sk);
 
        if (icsk->icsk_ca_ops->init)
                icsk->icsk_ca_ops->init(sk);
@@ -142,7 +152,6 @@ static int __init tcp_congestion_default(void)
 }
 late_initcall(tcp_congestion_default);
 
-
 /* Build string with list of available congestion control values */
 void tcp_get_available_congestion_control(char *buf, size_t maxlen)
 {
@@ -154,7 +163,6 @@ void tcp_get_available_congestion_control(char *buf, size_t maxlen)
                offs += snprintf(buf + offs, maxlen - offs,
                                 "%s%s",
                                 offs == 0 ? "" : " ", ca->name);
-
        }
        rcu_read_unlock();
 }
@@ -186,7 +194,6 @@ void tcp_get_allowed_congestion_control(char *buf, size_t maxlen)
                offs += snprintf(buf + offs, maxlen - offs,
                                 "%s%s",
                                 offs == 0 ? "" : " ", ca->name);
-
        }
        rcu_read_unlock();
 }
@@ -230,7 +237,6 @@ out:
        return ret;
 }
 
-
 /* Change congestion control for socket */
 int tcp_set_congestion_control(struct sock *sk, const char *name)
 {
@@ -337,6 +343,7 @@ EXPORT_SYMBOL_GPL(tcp_reno_cong_avoid);
 u32 tcp_reno_ssthresh(struct sock *sk)
 {
        const struct tcp_sock *tp = tcp_sk(sk);
+
        return max(tp->snd_cwnd >> 1U, 2U);
 }
 EXPORT_SYMBOL_GPL(tcp_reno_ssthresh);
@@ -348,15 +355,3 @@ struct tcp_congestion_ops tcp_reno = {
        .ssthresh       = tcp_reno_ssthresh,
        .cong_avoid     = tcp_reno_cong_avoid,
 };
-
-/* Initial congestion control used (until SYN)
- * really reno under another name so we can tell difference
- * during tcp_set_default_congestion_control
- */
-struct tcp_congestion_ops tcp_init_congestion_ops  = {
-       .name           = "",
-       .owner          = THIS_MODULE,
-       .ssthresh       = tcp_reno_ssthresh,
-       .cong_avoid     = tcp_reno_cong_avoid,
-};
-EXPORT_SYMBOL_GPL(tcp_init_congestion_ops);