netdev-dpdk: fix mbuf leaks
[cascardo/ovs.git] / lib / stream-ssl.c
index ddf388f..2699633 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2008, 2009, 2010, 2011, 2012, 2013, 2014 Nicira, Inc.
+ * Copyright (c) 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016 Nicira, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -45,7 +45,7 @@
 #include "stream-provider.h"
 #include "stream.h"
 #include "timeval.h"
-#include "vlog.h"
+#include "openvswitch/vlog.h"
 
 #ifdef _WIN32
 /* Ref: https://www.openssl.org/support/faq.html#PROG2
@@ -60,8 +60,6 @@
  * compiled with /MD is not tested. */
 #include <openssl/applink.c>
 #define SHUT_RDWR SD_BOTH
-#else
-#define closesocket close
 #endif
 
 VLOG_DEFINE_THIS_MODULE(stream_ssl);
@@ -84,7 +82,6 @@ struct ssl_stream
     enum ssl_state state;
     enum session_type type;
     int fd;
-    HANDLE wevent;
     SSL *ssl;
     struct ofpbuf *txbuf;
     unsigned int session_nr;
@@ -202,7 +199,6 @@ static void ssl_protocol_cb(int write_p, int version, int content_type,
                             const void *, size_t, SSL *, void *sslv_);
 static bool update_ssl_config(struct ssl_config_file *, const char *file_name);
 static int sock_errno(void);
-static void clear_handle(int fd, HANDLE wevent);
 
 static short int
 want_to_poll_events(int want)
@@ -226,11 +222,8 @@ static int
 new_ssl_stream(const char *name, int fd, enum session_type type,
                enum ssl_state state, struct stream **streamp)
 {
-    struct sockaddr_storage local;
-    socklen_t local_len = sizeof local;
     struct ssl_stream *sslv;
     SSL *ssl = NULL;
-    int on = 1;
     int retval;
 
     /* Check for all the needful configuration. */
@@ -256,19 +249,11 @@ new_ssl_stream(const char *name, int fd, enum session_type type,
         goto error;
     }
 
-    /* Get the local IP and port information */
-    retval = getsockname(fd, (struct sockaddr *) &local, &local_len);
-    if (retval) {
-        memset(&local, 0, sizeof local);
-    }
-
-    /* Disable Nagle. */
-    retval = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on, sizeof on);
-    if (retval) {
-        retval = sock_errno();
-        VLOG_ERR("%s: setsockopt(TCP_NODELAY): %s", name,
-                 sock_strerror(retval));
-        goto error;
+    /* Disable Nagle.
+     * On windows platforms, this can only be called upon TCP connected.
+     */
+    if (state == STATE_SSL_CONNECTING) {
+        setsockopt_tcp_nodelay(fd);
     }
 
     /* Create and configure OpenSSL stream. */
@@ -293,9 +278,6 @@ new_ssl_stream(const char *name, int fd, enum session_type type,
     sslv->state = state;
     sslv->type = type;
     sslv->fd = fd;
-#ifdef _WIN32
-    sslv->wevent = CreateEvent(NULL, FALSE, FALSE, NULL);
-#endif
     sslv->ssl = ssl;
     sslv->txbuf = NULL;
     sslv->rx_want = sslv->tx_want = SSL_NOTHING;
@@ -335,7 +317,7 @@ ssl_open(const char *name, char *suffix, struct stream **streamp, uint8_t dscp)
         return error;
     }
 
-    error = inet_open_active(SOCK_STREAM, suffix, OFP_OLD_PORT, NULL, &fd,
+    error = inet_open_active(SOCK_STREAM, suffix, OFP_PORT, NULL, &fd,
                              dscp);
     if (fd >= 0) {
         int state = error ? STATE_TCP_CONNECTING : STATE_SSL_CONNECTING;
@@ -426,12 +408,6 @@ do_ca_cert_bootstrap(struct stream *stream)
     /* SSL_CTX_add_client_CA makes a copy of cert's relevant data. */
     SSL_CTX_add_client_CA(ctx, cert);
 
-    /* SSL_CTX_use_certificate() takes ownership of the certificate passed in.
-     * 'cert' is owned by sslv->ssl, so we need to duplicate it. */
-    cert = X509_dup(cert);
-    if (!cert) {
-        out_of_memory();
-    }
     SSL_CTX_set_cert_store(ctx, X509_STORE_new());
     if (SSL_CTX_load_verify_locations(ctx, ca_cert.file_name, NULL) != 1) {
         VLOG_ERR("SSL_CTX_load_verify_locations: %s",
@@ -455,6 +431,7 @@ ssl_connect(struct stream *stream)
             return retval;
         }
         sslv->state = STATE_SSL_CONNECTING;
+        setsockopt_tcp_nodelay(sslv->fd);
         /* Fall through. */
 
     case STATE_SSL_CONNECTING:
@@ -478,7 +455,7 @@ ssl_connect(struct stream *stream)
                                      : "SSL_accept"), retval, error, &unused);
                 shutdown(sslv->fd, SHUT_RDWR);
                 stream_report_content(sslv->head, sslv->n_head, STREAM_SSL,
-                                      THIS_MODULE, stream_get_name(stream));
+                                      &this_module, stream_get_name(stream));
                 return EPROTO;
             }
         } else if (bootstrap_ca_cert) {
@@ -524,7 +501,6 @@ ssl_close(struct stream *stream)
     ERR_clear_error();
 
     SSL_free(sslv->ssl);
-    clear_handle(sslv->fd, sslv->wevent);
     closesocket(sslv->fd);
     free(sslv);
 }
@@ -716,8 +692,7 @@ ssl_run_wait(struct stream *stream)
     struct ssl_stream *sslv = ssl_stream_cast(stream);
 
     if (sslv->tx_want != SSL_NOTHING) {
-        poll_fd_wait_event(sslv->fd, sslv->wevent,
-                           want_to_poll_events(sslv->tx_want));
+        poll_fd_wait(sslv->fd, want_to_poll_events(sslv->tx_want));
     }
 }
 
@@ -733,14 +708,14 @@ ssl_wait(struct stream *stream, enum stream_wait_type wait)
         } else {
             switch (sslv->state) {
             case STATE_TCP_CONNECTING:
-                poll_fd_wait_event(sslv->fd, sslv->wevent, POLLOUT);
+                poll_fd_wait(sslv->fd, POLLOUT);
                 break;
 
             case STATE_SSL_CONNECTING:
                 /* ssl_connect() called SSL_accept() or SSL_connect(), which
                  * set up the status that we test here. */
-                poll_fd_wait_event(sslv->fd, sslv->wevent,
-                                   want_to_poll_events(SSL_want(sslv->ssl)));
+                poll_fd_wait(sslv->fd,
+                               want_to_poll_events(SSL_want(sslv->ssl)));
                 break;
 
             default:
@@ -751,8 +726,7 @@ ssl_wait(struct stream *stream, enum stream_wait_type wait)
 
     case STREAM_RECV:
         if (sslv->rx_want != SSL_NOTHING) {
-            poll_fd_wait_event(sslv->fd, sslv->wevent,
-                               want_to_poll_events(sslv->rx_want));
+            poll_fd_wait(sslv->fd, want_to_poll_events(sslv->rx_want));
         } else {
             poll_immediate_wake();
         }
@@ -792,7 +766,6 @@ struct pssl_pstream
 {
     struct pstream pstream;
     int fd;
-    HANDLE wevent;
 };
 
 const struct pstream_class pssl_pstream_class;
@@ -821,22 +794,19 @@ pssl_open(const char *name OVS_UNUSED, char *suffix, struct pstream **pstreamp,
         return retval;
     }
 
-    fd = inet_open_passive(SOCK_STREAM, suffix, OFP_OLD_PORT, &ss, dscp);
+    fd = inet_open_passive(SOCK_STREAM, suffix, OFP_PORT, &ss, dscp, true);
     if (fd < 0) {
         return -fd;
     }
 
     port = ss_get_port(&ss);
-    snprintf(bound_name, sizeof bound_name, "ptcp:%"PRIu16":%s",
+    snprintf(bound_name, sizeof bound_name, "pssl:%"PRIu16":%s",
              port, ss_format_address(&ss, addrbuf, sizeof addrbuf));
 
     pssl = xmalloc(sizeof *pssl);
     pstream_init(&pssl->pstream, &pssl_pstream_class, bound_name);
     pstream_set_bound_port(&pssl->pstream, htons(port));
     pssl->fd = fd;
-#ifdef _WIN32
-    pssl->wevent = CreateEvent(NULL, FALSE, FALSE, NULL);
-#endif
     *pstreamp = &pssl->pstream;
     return 0;
 }
@@ -845,7 +815,6 @@ static void
 pssl_close(struct pstream *pstream)
 {
     struct pssl_pstream *pssl = pssl_pstream_cast(pstream);
-    clear_handle(pssl->fd, pssl->wevent);
     closesocket(pssl->fd);
     free(pssl);
 }
@@ -881,7 +850,7 @@ pssl_accept(struct pstream *pstream, struct stream **new_streamp)
         return error;
     }
 
-    snprintf(name, sizeof name, "tcp:%s:%"PRIu16,
+    snprintf(name, sizeof name, "ssl:%s:%"PRIu16,
              ss_format_address(&ss, addrbuf, sizeof addrbuf),
              ss_get_port(&ss));
     return new_ssl_stream(name, new_fd, SERVER, STATE_SSL_CONNECTING,
@@ -892,14 +861,7 @@ static void
 pssl_wait(struct pstream *pstream)
 {
     struct pssl_pstream *pssl = pssl_pstream_cast(pstream);
-    poll_fd_wait_event(pssl->fd, pssl->wevent, POLLIN);
-}
-
-static int
-pssl_set_dscp(struct pstream *pstream, uint8_t dscp)
-{
-    struct pssl_pstream *pssl = pssl_pstream_cast(pstream);
-    return set_dscp(pssl->fd, dscp);
+    poll_fd_wait(pssl->fd, POLLIN);
 }
 
 const struct pstream_class pssl_pstream_class = {
@@ -909,7 +871,6 @@ const struct pstream_class pssl_pstream_class = {
     pssl_close,
     pssl_accept,
     pssl_wait,
-    pssl_set_dscp,
 };
 \f
 /*
@@ -977,9 +938,17 @@ do_ssl_init(void)
         RAND_seed(seed, sizeof seed);
     }
 
-    /* New OpenSSL changed TLSv1_method() to return a "const" pointer, so the
-     * cast is needed to avoid a warning with those newer versions. */
-    method = CONST_CAST(SSL_METHOD *, TLSv1_method());
+    /* OpenSSL has a bunch of "connection methods": SSLv2_method(),
+     * SSLv3_method(), TLSv1_method(), SSLv23_method(), ...  Most of these
+     * support exactly one version of SSL, e.g. TLSv1_method() supports TLSv1
+     * only, not any earlier *or later* version.  The only exception is
+     * SSLv23_method(), which in fact supports *any* version of SSL and TLS.
+     * We don't want SSLv2 or SSLv3 support, so we turn it off below with
+     * SSL_CTX_set_options().
+     *
+     * The cast is needed to avoid a warning with newer versions of OpenSSL in
+     * which SSLv23_method() returns a "const" pointer. */
+    method = CONST_CAST(SSL_METHOD *, SSLv23_method());
     if (method == NULL) {
         VLOG_ERR("TLSv1_method: %s", ERR_error_string(ERR_get_error(), NULL));
         return ENOPROTOOPT;
@@ -996,6 +965,7 @@ do_ssl_init(void)
     SSL_CTX_set_mode(ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
     SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
                        NULL);
+    SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);
 
     return 0;
 }
@@ -1095,7 +1065,7 @@ stream_ssl_set_private_key_file(const char *file_name)
 static void
 stream_ssl_set_certificate_file__(const char *file_name)
 {
-    if (SSL_CTX_use_certificate_chain_file(ctx, file_name) == 1) {
+    if (SSL_CTX_use_certificate_file(ctx, file_name, SSL_FILETYPE_PEM) == 1) {
         certificate.read = true;
     } else {
         VLOG_ERR("SSL_use_certificate_file: %s",
@@ -1184,6 +1154,7 @@ read_cert_file(const char *file_name, X509 ***certs, size_t *n_certs)
             free(*certs);
             *certs = NULL;
             *n_certs = 0;
+            fclose(file);
             return EIO;
         }
 
@@ -1265,8 +1236,6 @@ static void
 stream_ssl_set_ca_cert_file__(const char *file_name,
                               bool bootstrap, bool force)
 {
-    X509 **certs;
-    size_t n_certs;
     struct stat s;
 
     if (!update_ssl_config(&ca_cert, file_name) && !force) {
@@ -1279,33 +1248,26 @@ stream_ssl_set_ca_cert_file__(const char *file_name,
                   "(this is a security risk)");
     } else if (bootstrap && stat(file_name, &s) && errno == ENOENT) {
         bootstrap_ca_cert = true;
-    } else if (!read_cert_file(file_name, &certs, &n_certs)) {
-        size_t i;
-
-        /* Set up list of CAs that the server will accept from the client. */
-        for (i = 0; i < n_certs; i++) {
-            /* SSL_CTX_add_client_CA makes a copy of the relevant data. */
-            if (SSL_CTX_add_client_CA(ctx, certs[i]) != 1) {
-                VLOG_ERR("failed to add client certificate %"PRIuSIZE" from %s: %s",
-                         i, file_name,
+    } else {
+        STACK_OF(X509_NAME) *cert_names = SSL_load_client_CA_file(file_name);
+        if (cert_names) {
+            /* Set up list of CAs that the server will accept from the
+             * client. */
+            SSL_CTX_set_client_CA_list(ctx, cert_names);
+
+            /* Set up CAs for OpenSSL to trust in verifying the peer's
+             * certificate. */
+            SSL_CTX_set_cert_store(ctx, X509_STORE_new());
+            if (SSL_CTX_load_verify_locations(ctx, file_name, NULL) != 1) {
+                VLOG_ERR("SSL_CTX_load_verify_locations: %s",
                          ERR_error_string(ERR_get_error(), NULL));
-            } else {
-                log_ca_cert(file_name, certs[i]);
+                return;
             }
-            X509_free(certs[i]);
-        }
-        free(certs);
-
-        /* Set up CAs for OpenSSL to trust in verifying the peer's
-         * certificate. */
-        SSL_CTX_set_cert_store(ctx, X509_STORE_new());
-        if (SSL_CTX_load_verify_locations(ctx, file_name, NULL) != 1) {
-            VLOG_ERR("SSL_CTX_load_verify_locations: %s",
-                     ERR_error_string(ERR_get_error(), NULL));
-            return;
+            bootstrap_ca_cert = false;
+        } else {
+            VLOG_ERR("failed to load client certificates from %s: %s",
+                     file_name, ERR_error_string(ERR_get_error(), NULL));
         }
-
-        bootstrap_ca_cert = false;
     }
     ca_cert.read = true;
 }
@@ -1415,28 +1377,3 @@ ssl_protocol_cb(int write_p, int version OVS_UNUSED, int content_type,
 
     ds_destroy(&details);
 }
-
-/* In Windows platform, errno is not set for socket calls.
- * The last error has to be gotten from WSAGetLastError(). */
-static int
-sock_errno(void)
-{
-#ifdef _WIN32
-    return WSAGetLastError();
-#else
-    return errno;
-#endif
-}
-
-static void
-clear_handle(int fd OVS_UNUSED, HANDLE wevent OVS_UNUSED)
-{
-#ifdef _WIN32
-    if (fd) {
-        WSAEventSelect(fd, NULL, 0);
-    }
-    if (wevent) {
-        CloseHandle(wevent);
-    }
-#endif
-}