Check for driver methods before calling them and reset them on close.
[cascardo/rnetproxy.git] / hcconn_ssl.c
1 /*
2 ** Copyright (C) 2006 Thadeu Lima de Souza Cascardo <cascardo@minaslivre.org>
3 ** Copyright (C) 2009 Thadeu Lima de Souza Cascardo <cascardo@holoscopio.com>
4 **  
5 ** This program is free software; you can redistribute it and/or modify
6 ** it under the terms of the GNU General Public License as published by
7 ** the Free Software Foundation; either version 2 of the License, or
8 ** (at your option) any later version.
9 **  
10 ** This program is distributed in the hope that it will be useful,
11 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
12 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 ** GNU General Public License for more details.
14 **  
15 ** You should have received a copy of the GNU General Public License
16 ** along with this program; if not, write to the Free Software
17 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
18 **  
19 */
20
21 #include <gnutls/gnutls.h>
22 #include <glib.h>
23 #include <string.h>
24 #include <errno.h>
25 #include <fcntl.h>
26 #include "hcconn_internal.h"
27
28 struct ssl_data
29 {
30   gnutls_session_t session;
31   GString *buffer;
32   gboolean handshaking;
33   gpointer lowconn;
34 };
35
36 static struct ssl_data *
37 ssl_data_new (void)
38 {
39   struct ssl_data *ssl;
40   int kx_prio[] = {GNUTLS_KX_RSA, 0};
41   gnutls_certificate_credentials cred;
42   gnutls_certificate_allocate_credentials (&cred);
43   ssl = g_slice_new (struct ssl_data);
44   gnutls_init (&ssl->session, GNUTLS_CLIENT);
45   gnutls_set_default_priority (ssl->session);
46   gnutls_kx_set_priority (ssl->session, kx_prio);
47   gnutls_credentials_set (ssl->session, GNUTLS_CRD_CERTIFICATE, cred);
48   ssl->buffer = g_string_sized_new (4096);
49   ssl->handshaking = FALSE;
50   return ssl;
51 }
52
53 static void
54 ssl_data_destroy (struct ssl_data *ssl)
55 {
56   gnutls_deinit (ssl->session);
57   g_string_free (ssl->buffer, TRUE);
58   g_slice_free (struct ssl_data, ssl);
59 }
60
61 static ssize_t
62 ssl_push (gnutls_transport_ptr_t ptr, const void *buffer, size_t len)
63 {
64   HCConn *conn = ptr;
65   struct ssl_data *ssl = conn->layer;
66   hc_conn_write (ssl->lowconn, (void *) buffer, len);
67   return len;
68 }
69
70 static ssize_t
71 ssl_pull (gnutls_transport_ptr_t ptr, void *buffer, size_t len)
72 {
73   HCConn *conn = ptr;
74   struct ssl_data *ssl = conn->layer;
75   int r;
76   if (ssl->handshaking == TRUE)
77     {
78       r = hc_conn_read (ssl->lowconn, buffer, len);
79       return r;
80     }
81   if (len > ssl->buffer->len)
82     {
83       r = ssl->buffer->len;
84       memcpy (buffer, ssl->buffer->str, r);
85       g_string_truncate (ssl->buffer, 0);
86     }
87   else
88     {
89       r = len;
90       memcpy (buffer, ssl->buffer->str, r);
91       g_string_erase (ssl->buffer, 0, r);
92     }
93   if (r == 0)
94     {
95       gnutls_transport_set_errno (ssl->session, EAGAIN);
96       return -1;
97     }
98   return r;
99 }
100
101 static void
102 ssl_server_handshake (struct ssl_data *ssl)
103 {
104   int error;
105   if ((error = gnutls_handshake (ssl->session)) < 0)
106     {
107       if (gnutls_error_is_fatal (error))
108         g_critical ("Fatal error while doing TLS handshaking: %s\n",
109                     gnutls_strerror (error));
110     }
111   else
112     {
113       ssl->handshaking = FALSE;
114     }
115 }
116
117 static void
118 ssl_server_connect (HCConn *conn)
119 {
120   struct ssl_data *ssl = conn->layer;
121   gnutls_transport_set_ptr (ssl->session, (gnutls_transport_ptr_t) conn);
122   gnutls_transport_set_push_function (ssl->session, ssl_push);
123   gnutls_transport_set_pull_function (ssl->session, ssl_pull);
124   ssl->handshaking = TRUE;
125   ssl_server_handshake (ssl);
126 }
127
128 static void
129 hc_conn_ssl_close (gpointer data)
130 {
131   struct ssl_data *ssl = data;
132   if (ssl != NULL)
133     {
134       gnutls_bye (ssl->session, GNUTLS_SHUT_RDWR);
135       hc_conn_close (ssl->lowconn);
136       ssl_data_destroy (ssl);
137     }
138 }
139
140 static ssize_t
141 hc_conn_ssl_read (gpointer data, gchar *buffer, size_t len)
142 {
143   struct ssl_data *ssl = data;
144   return gnutls_record_recv (ssl->session, buffer, len);
145 }
146
147 static ssize_t
148 hc_conn_ssl_write (gpointer data, gchar *buffer, size_t len)
149 {
150   struct ssl_data *ssl = data;
151   return gnutls_record_send (ssl->session, buffer, len);
152 }
153
154 void
155 hc_conn_ssl_watch (HCConn *conn, HCEvent event, gpointer data)
156 {
157   char buffer[4096];
158   HCConn *ssl_conn = data;
159   struct ssl_data *ssl = ssl_conn->layer;
160   int r;
161   if (event != HC_EVENT_READ)
162     return;
163   if (ssl->handshaking)
164     {
165       ssl_server_handshake (ssl);
166       return;
167     }
168   while ((r = hc_conn_read (ssl->lowconn, buffer, sizeof (buffer))) > 0)
169     g_string_append_len (ssl->buffer, buffer, r);
170   if (ssl_conn->func && !ssl->handshaking)
171     ssl_conn->func (ssl_conn, event, ssl_conn->data);
172   return;
173 }
174
175 void
176 hc_conn_set_driver_ssl (HCConn *conn, HCConn *lowconn)
177 {
178   struct ssl_data *ssl = ssl_data_new ();
179   ssl->lowconn = lowconn;
180   conn->layer = ssl;
181   conn->read = hc_conn_ssl_read;
182   conn->write = hc_conn_ssl_write;
183   conn->close = hc_conn_ssl_close;
184   hc_conn_set_callback (lowconn, hc_conn_ssl_watch, conn);
185   ssl_server_connect (conn);
186 }