Added separate header for SSL connection support.
[cascardo/rnetproxy.git] / hcconn_ssl.c
1 /*
2 ** Copyright (C) 2006 Thadeu Lima de Souza Cascardo <cascardo@minaslivre.org>
3 ** Copyright (C) 2009 Thadeu Lima de Souza Cascardo <cascardo@holoscopio.com>
4 **  
5 ** This program is free software; you can redistribute it and/or modify
6 ** it under the terms of the GNU General Public License as published by
7 ** the Free Software Foundation; either version 2 of the License, or
8 ** (at your option) any later version.
9 **  
10 ** This program is distributed in the hope that it will be useful,
11 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
12 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 ** GNU General Public License for more details.
14 **  
15 ** You should have received a copy of the GNU General Public License
16 ** along with this program; if not, write to the Free Software
17 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
18 **  
19 */
20
21 #include <gnutls/gnutls.h>
22 #include <glib.h>
23 #include <string.h>
24 #include <errno.h>
25 #include <fcntl.h>
26 #include "hcconn_internal.h"
27
28 struct ssl_data
29 {
30   gnutls_session_t session;
31   GString *buffer;
32   gboolean handshaking;
33   gpointer lowconn;
34 };
35
36 #define DH_BITS 1024
37 static gnutls_anon_server_credentials_t
38 ssl_server_get_credentials (void)
39 {
40   static int initialized = 0;
41   static gnutls_anon_server_credentials_t cred;
42   gnutls_dh_params_t dh_params;
43   if (initialized)
44     return cred;
45   gnutls_dh_params_init (&dh_params);
46   gnutls_dh_params_generate2 (dh_params, DH_BITS);
47   gnutls_anon_allocate_server_credentials (&cred);
48   gnutls_anon_set_server_dh_params (cred, dh_params);
49   initialized = 1;
50   return cred;
51 }
52
53 static void
54 ssl_server_session_new (gnutls_session_t *session)
55 {
56   static gnutls_anon_server_credentials_t cred;
57   cred = ssl_server_get_credentials ();
58   gnutls_init (session, GNUTLS_SERVER);
59   gnutls_priority_set_direct (*session, "NORMAL:+ANON-DH", NULL);
60   gnutls_credentials_set (*session, GNUTLS_CRD_ANON, cred);
61   gnutls_dh_set_prime_bits (*session, DH_BITS);
62 }
63 #undef DH_BITS
64
65 static void
66 ssl_client_session_new (gnutls_session_t *session)
67 {
68   int kx_prio[] = {GNUTLS_KX_RSA, 0};
69   gnutls_certificate_credentials cred;
70   gnutls_certificate_allocate_credentials (&cred);
71   gnutls_init (session, GNUTLS_CLIENT);
72   gnutls_set_default_priority (*session);
73   gnutls_kx_set_priority (*session, kx_prio);
74   gnutls_credentials_set (*session, GNUTLS_CRD_CERTIFICATE, cred);
75 }
76
77 static struct ssl_data *
78 ssl_data_new (int server)
79 {
80   struct ssl_data *ssl;
81   ssl = g_slice_new (struct ssl_data);
82   if (server)
83     ssl_server_session_new (&ssl->session);
84   else
85     ssl_client_session_new (&ssl->session);
86   ssl->buffer = g_string_sized_new (4096);
87   ssl->handshaking = FALSE;
88   return ssl;
89 }
90
91 static void
92 ssl_data_destroy (struct ssl_data *ssl)
93 {
94   gnutls_deinit (ssl->session);
95   g_string_free (ssl->buffer, TRUE);
96   g_slice_free (struct ssl_data, ssl);
97 }
98
99 static ssize_t
100 ssl_push (gnutls_transport_ptr_t ptr, const void *buffer, size_t len)
101 {
102   HCConn *conn = ptr;
103   struct ssl_data *ssl = conn->layer;
104   hc_conn_write (ssl->lowconn, (void *) buffer, len);
105   return len;
106 }
107
108 static ssize_t
109 ssl_pull (gnutls_transport_ptr_t ptr, void *buffer, size_t len)
110 {
111   HCConn *conn = ptr;
112   struct ssl_data *ssl = conn->layer;
113   int r;
114   if (ssl->handshaking == TRUE)
115     {
116       r = hc_conn_read (ssl->lowconn, buffer, len);
117       return r;
118     }
119   if (len > ssl->buffer->len)
120     {
121       r = ssl->buffer->len;
122       memcpy (buffer, ssl->buffer->str, r);
123       g_string_truncate (ssl->buffer, 0);
124     }
125   else
126     {
127       r = len;
128       memcpy (buffer, ssl->buffer->str, r);
129       g_string_erase (ssl->buffer, 0, r);
130     }
131   if (r == 0)
132     {
133       gnutls_transport_set_errno (ssl->session, EAGAIN);
134       return -1;
135     }
136   return r;
137 }
138
139 static void
140 ssl_server_handshake (struct ssl_data *ssl)
141 {
142   int error;
143   if ((error = gnutls_handshake (ssl->session)) < 0)
144     {
145       if (gnutls_error_is_fatal (error))
146         g_critical ("Fatal error while doing TLS handshaking: %s\n",
147                     gnutls_strerror (error));
148     }
149   else
150     {
151       ssl->handshaking = FALSE;
152     }
153 }
154
155 static void
156 ssl_server_connect (HCConn *conn)
157 {
158   struct ssl_data *ssl = conn->layer;
159   gnutls_transport_set_ptr (ssl->session, (gnutls_transport_ptr_t) conn);
160   gnutls_transport_set_push_function (ssl->session, ssl_push);
161   gnutls_transport_set_pull_function (ssl->session, ssl_pull);
162   ssl->handshaking = TRUE;
163   ssl_server_handshake (ssl);
164 }
165
166 static void
167 hc_conn_ssl_close (gpointer data)
168 {
169   struct ssl_data *ssl = data;
170   if (ssl != NULL)
171     {
172       gnutls_bye (ssl->session, GNUTLS_SHUT_RDWR);
173       hc_conn_close (ssl->lowconn);
174       ssl_data_destroy (ssl);
175     }
176 }
177
178 static ssize_t
179 hc_conn_ssl_read (gpointer data, gchar *buffer, size_t len)
180 {
181   struct ssl_data *ssl = data;
182   return gnutls_record_recv (ssl->session, buffer, len);
183 }
184
185 static ssize_t
186 hc_conn_ssl_write (gpointer data, gchar *buffer, size_t len)
187 {
188   struct ssl_data *ssl = data;
189   return gnutls_record_send (ssl->session, buffer, len);
190 }
191
192 void
193 hc_conn_ssl_watch (HCConn *conn, HCEvent event, gpointer data)
194 {
195   char buffer[4096];
196   HCConn *ssl_conn = data;
197   struct ssl_data *ssl = ssl_conn->layer;
198   int r;
199   switch (event)
200     {
201     case HC_EVENT_READ:
202       if (ssl->handshaking)
203         {
204           ssl_server_handshake (ssl);
205           return;
206         }
207       while ((r = hc_conn_read (ssl->lowconn, buffer, sizeof (buffer))) > 0)
208         g_string_append_len (ssl->buffer, buffer, r);
209       if (ssl_conn->func && !ssl->handshaking)
210         ssl_conn->func (ssl_conn, event, ssl_conn->data);
211       break;
212     case HC_EVENT_CLOSE:
213       if (ssl_conn->func)
214         ssl_conn->func (ssl_conn, event, ssl_conn->data);
215     }
216 }
217
218 static int
219 hc_conn_set_driver_ssl (HCConn *conn, HCConn *lowconn, int server)
220 {
221   struct ssl_data *ssl;
222   ssl = ssl_data_new (server);
223   if (ssl == NULL)
224     return -1;
225   ssl->lowconn = lowconn;
226   conn->layer = ssl;
227   conn->read = hc_conn_ssl_read;
228   conn->write = hc_conn_ssl_write;
229   conn->close = hc_conn_ssl_close;
230   hc_conn_set_callback (lowconn, hc_conn_ssl_watch, conn);
231   ssl_server_connect (conn);
232   return 0;
233 }
234
235 int
236 hc_conn_set_driver_ssl_client (HCConn *conn, HCConn *lowconn)
237 {
238   return hc_conn_set_driver_ssl (conn, lowconn, 0);
239 }
240
241 int
242 hc_conn_set_driver_ssl_server (HCConn *conn, HCConn *lowconn)
243 {
244   return hc_conn_set_driver_ssl (conn, lowconn, 1);
245 }