If receive returns an error report it using CLOSE event.
[cascardo/rnetproxy.git] / hcconn_ssl.c
1 /*
2 ** Copyright (C) 2006 Thadeu Lima de Souza Cascardo <cascardo@minaslivre.org>
3 ** Copyright (C) 2009 Thadeu Lima de Souza Cascardo <cascardo@holoscopio.com>
4 **  
5 ** This program is free software; you can redistribute it and/or modify
6 ** it under the terms of the GNU General Public License as published by
7 ** the Free Software Foundation; either version 2 of the License, or
8 ** (at your option) any later version.
9 **  
10 ** This program is distributed in the hope that it will be useful,
11 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
12 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 ** GNU General Public License for more details.
14 **  
15 ** You should have received a copy of the GNU General Public License
16 ** along with this program; if not, write to the Free Software
17 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
18 **  
19 */
20
21 #include <gnutls/gnutls.h>
22 #include <glib.h>
23 #include <string.h>
24 #include <errno.h>
25 #include <fcntl.h>
26 #include "hcconn_internal.h"
27
28 struct ssl_data
29 {
30   gnutls_session_t session;
31   GString *buffer;
32   gboolean handshaking;
33   gpointer lowconn;
34 };
35
36 #define DH_BITS 1024
37 void *
38 hc_conn_ssl_server_init_credentials (char *certfile, char *keyfile)
39 {
40   static int initialized = 0;
41   static gnutls_certificate_credentials_t cred;
42   gnutls_dh_params_t dh_params;
43   if (initialized)
44     return cred;
45   gnutls_dh_params_init (&dh_params);
46   gnutls_dh_params_generate2 (dh_params, DH_BITS);
47   gnutls_certificate_allocate_credentials (&cred);
48   gnutls_certificate_set_x509_key_file (cred, certfile, keyfile,
49                                         GNUTLS_X509_FMT_PEM);
50   gnutls_certificate_set_dh_params (cred, dh_params);
51   initialized = 1;
52   return cred;
53 }
54
55 static void *
56 ssl_server_get_credentials(void)
57 {
58   return hc_conn_ssl_server_init_credentials (NULL, NULL);
59 }
60  
61 static void
62 ssl_server_session_new (gnutls_session_t *session)
63 {
64   static void *cred;
65   cred = ssl_server_get_credentials ();
66   gnutls_init (session, GNUTLS_SERVER);
67   gnutls_priority_set_direct (*session, "NORMAL", NULL);
68   gnutls_credentials_set (*session, GNUTLS_CRD_CERTIFICATE, cred);
69   gnutls_dh_set_prime_bits (*session, DH_BITS);
70 }
71 #undef DH_BITS
72
73 static void
74 ssl_client_session_new (gnutls_session_t *session)
75 {
76   int kx_prio[] = {GNUTLS_KX_RSA, 0};
77   gnutls_certificate_credentials cred;
78   gnutls_certificate_allocate_credentials (&cred);
79   gnutls_init (session, GNUTLS_CLIENT);
80   gnutls_set_default_priority (*session);
81   gnutls_kx_set_priority (*session, kx_prio);
82   gnutls_credentials_set (*session, GNUTLS_CRD_CERTIFICATE, cred);
83 }
84
85 static struct ssl_data *
86 ssl_data_new (int server)
87 {
88   struct ssl_data *ssl;
89   ssl = g_slice_new (struct ssl_data);
90   if (server)
91     ssl_server_session_new (&ssl->session);
92   else
93     ssl_client_session_new (&ssl->session);
94   ssl->buffer = g_string_sized_new (4096);
95   ssl->handshaking = FALSE;
96   return ssl;
97 }
98
99 static void
100 ssl_data_destroy (struct ssl_data *ssl)
101 {
102   gnutls_deinit (ssl->session);
103   g_string_free (ssl->buffer, TRUE);
104   g_slice_free (struct ssl_data, ssl);
105 }
106
107 static ssize_t
108 ssl_push (gnutls_transport_ptr_t ptr, const void *buffer, size_t len)
109 {
110   HCConn *conn = ptr;
111   struct ssl_data *ssl = conn->layer;
112   hc_conn_write (ssl->lowconn, (void *) buffer, len);
113   return len;
114 }
115
116 static ssize_t
117 ssl_pull (gnutls_transport_ptr_t ptr, void *buffer, size_t len)
118 {
119   HCConn *conn = ptr;
120   struct ssl_data *ssl = conn->layer;
121   int r;
122   if (ssl->handshaking == TRUE)
123     {
124       r = hc_conn_read (ssl->lowconn, buffer, len);
125       return r;
126     }
127   if (len > ssl->buffer->len)
128     {
129       r = ssl->buffer->len;
130       memcpy (buffer, ssl->buffer->str, r);
131       g_string_truncate (ssl->buffer, 0);
132     }
133   else
134     {
135       r = len;
136       memcpy (buffer, ssl->buffer->str, r);
137       g_string_erase (ssl->buffer, 0, r);
138     }
139   if (r == 0)
140     {
141       gnutls_transport_set_errno (ssl->session, EAGAIN);
142       return -1;
143     }
144   return r;
145 }
146
147 static void
148 ssl_server_handshake (struct ssl_data *ssl)
149 {
150   int error;
151   if ((error = gnutls_handshake (ssl->session)) < 0)
152     {
153       if (gnutls_error_is_fatal (error))
154         g_critical ("Fatal error while doing TLS handshaking: %s\n",
155                     gnutls_strerror (error));
156     }
157   else
158     {
159       ssl->handshaking = FALSE;
160     }
161 }
162
163 static void
164 ssl_server_connect (HCConn *conn)
165 {
166   struct ssl_data *ssl = conn->layer;
167   gnutls_transport_set_ptr (ssl->session, (gnutls_transport_ptr_t) conn);
168   gnutls_transport_set_push_function (ssl->session, ssl_push);
169   gnutls_transport_set_pull_function (ssl->session, ssl_pull);
170   ssl->handshaking = TRUE;
171   ssl_server_handshake (ssl);
172 }
173
174 static void
175 hc_conn_ssl_close (gpointer data)
176 {
177   struct ssl_data *ssl = data;
178   if (ssl != NULL)
179     {
180       gnutls_bye (ssl->session, GNUTLS_SHUT_RDWR);
181       hc_conn_close (ssl->lowconn);
182       ssl_data_destroy (ssl);
183     }
184 }
185
186 static ssize_t
187 hc_conn_ssl_read (gpointer data, gchar *buffer, size_t len)
188 {
189   struct ssl_data *ssl = data;
190   return gnutls_record_recv (ssl->session, buffer, len);
191 }
192
193 static ssize_t
194 hc_conn_ssl_write (gpointer data, gchar *buffer, size_t len)
195 {
196   struct ssl_data *ssl = data;
197   return gnutls_record_send (ssl->session, buffer, len);
198 }
199
200 void
201 hc_conn_ssl_watch (HCConn *conn, HCEvent event, gpointer data)
202 {
203   char buffer[4096];
204   HCConn *ssl_conn = data;
205   struct ssl_data *ssl = ssl_conn->layer;
206   int r;
207   switch (event)
208     {
209     case HC_EVENT_READ:
210       if (ssl->handshaking)
211         {
212           ssl_server_handshake (ssl);
213           return;
214         }
215       while ((r = hc_conn_read (ssl->lowconn, buffer, sizeof (buffer))) > 0)
216         g_string_append_len (ssl->buffer, buffer, r);
217       if (ssl_conn->func && !ssl->handshaking)
218         ssl_conn->func (ssl_conn, event, ssl_conn->data);
219       break;
220     case HC_EVENT_CLOSE:
221       if (ssl_conn->func)
222         ssl_conn->func (ssl_conn, event, ssl_conn->data);
223     }
224 }
225
226 static int
227 hc_conn_set_driver_ssl (HCConn *conn, HCConn *lowconn, int server)
228 {
229   struct ssl_data *ssl;
230   ssl = ssl_data_new (server);
231   if (ssl == NULL)
232     return -1;
233   ssl->lowconn = lowconn;
234   conn->layer = ssl;
235   conn->read = hc_conn_ssl_read;
236   conn->write = hc_conn_ssl_write;
237   conn->close = hc_conn_ssl_close;
238   hc_conn_set_callback (lowconn, hc_conn_ssl_watch, conn);
239   ssl_server_connect (conn);
240   return 0;
241 }
242
243 int
244 hc_conn_set_driver_ssl_client (HCConn *conn, HCConn *lowconn)
245 {
246   return hc_conn_set_driver_ssl (conn, lowconn, 0);
247 }
248
249 int
250 hc_conn_set_driver_ssl_server (HCConn *conn, HCConn *lowconn)
251 {
252   return hc_conn_set_driver_ssl (conn, lowconn, 1);
253 }