fix this too
[cascardo/rnetproxy.git] / rnetserver.c
1 /*
2  *  Copyright (C) 2011  Thadeu Lima de Souza Cascardo <cascardo@holoscopio.com>
3  *
4  *  This program is free software; you can redistribute it and/or modify
5  *  it under the terms of the GNU General Public License as published by
6  *  the Free Software Foundation; either version 2 of the License, or
7  *  (at your option) any later version.
8  *
9  *  This program is distributed in the hope that it will be useful,
10  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
11  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  *  GNU General Public License for more details.
13  *
14  *  You should have received a copy of the GNU General Public License along
15  *  with this program; if not, write to the Free Software Foundation, Inc.,
16  *  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
17  */
18
19 #include <string.h>
20 #include <stdlib.h>
21 #include <stdio.h>
22 #include <unistd.h>
23 #include <sys/socket.h>
24 #include <netinet/in.h>
25 #include <arpa/inet.h>
26 #include <gnutls/gnutls.h>
27 #include <zlib.h>
28 #include <resp.c>
29
30 #define DH_BITS 1024
31 static void * get_creds(char *certfile, char *keyfile)
32 {
33         static gnutls_certificate_credentials_t cred;
34         gnutls_dh_params_t dh_params;
35         gnutls_dh_params_init(&dh_params);
36         gnutls_dh_params_generate2(dh_params, DH_BITS);
37         gnutls_certificate_allocate_credentials(&cred);
38         gnutls_certificate_set_x509_key_file(cred, certfile, keyfile,
39                                         GNUTLS_X509_FMT_PEM);
40         gnutls_certificate_set_dh_params(cred, dh_params);
41         return cred;
42 }
43
44 static void session_new(gnutls_session_t *session)
45 {
46         static void *cred;
47         cred = get_creds("cert.pem", "key.pem");
48         gnutls_init(session, GNUTLS_SERVER);
49         gnutls_set_default_priority(*session);
50         gnutls_credentials_set(*session, GNUTLS_CRD_CERTIFICATE, cred);
51         gnutls_dh_set_prime_bits(*session, DH_BITS);
52 }
53 #undef DH_BITS
54
55 static int buildRecord(char *buffer, size_t len, char **out, size_t *olen)
56 {
57         *olen = len + 4;
58         *out = malloc(*olen);
59         (*out)[0] = 0x0;
60         (*out)[1] = (len >> 8);
61         (*out)[2] = (len & 0xff);
62         (*out)[3] = 0x1;
63         memcpy(*out + 4, buffer, len);
64         return 0;
65 }
66
67 static int deflateRecord(char *buffer, size_t len, char **out, size_t *olen)
68 {
69         z_stream zstrm;
70         int r;
71         zstrm.zalloc = Z_NULL;
72         zstrm.zfree = Z_NULL;
73         zstrm.opaque = Z_NULL;
74         if ((r = deflateInit(&zstrm, Z_DEFAULT_COMPRESSION)) != Z_OK)
75                 return -1;
76         *out = malloc(len * 2 + 36);
77         if (!out) {
78                 deflateEnd(&zstrm);
79                 return -1;
80         }
81         zstrm.next_in = buffer;
82         zstrm.avail_in = len;
83         zstrm.next_out = *out + 6;
84         zstrm.avail_out = len * 2 + 30;
85         while ((r = deflate(&zstrm, Z_FINISH)) != Z_STREAM_END &&
86                 zstrm.avail_out > 0);
87         if ((r = deflate(&zstrm, Z_FINISH)) != Z_STREAM_END) {
88                 deflateEnd(&zstrm);
89                 free(*out);
90                 return -1;
91         }
92         *olen = zstrm.avail_out + 6;
93         (*out)[0] = 0x1;
94         (*out)[1] = (zstrm.avail_out >> 8);
95         (*out)[2] = (zstrm.avail_out & 0xff);
96         (*out)[3] = (len >> 8);
97         (*out)[4] = (len & 0xff);
98         (*out)[5] = 0x1;
99         deflateEnd(&zstrm);
100         return 0;
101 }
102
103
104 static char *response;
105 static char *def_response;
106
107 int main(int argc, char **argv)
108 {
109         int s;
110         struct sockaddr_in saddr;
111         int c;
112         int r;
113         char buffer[8192];
114         int resp_size;
115         int def_resp_size;
116         int count = 0;
117         int val = 1;
118         gnutls_session_t session;
119         gnutls_global_init();
120         session_new(&session);
121         s = socket(PF_INET, SOCK_STREAM, 0);
122         setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));
123         saddr.sin_family = AF_INET;
124         saddr.sin_port = htons(3456);
125         saddr.sin_addr.s_addr = htonl(INADDR_ANY);
126         bind(s, (struct sockaddr *) &saddr, sizeof(saddr));
127         listen(s, 5);
128         c = accept(s, NULL, NULL);
129         close(s);
130         gnutls_transport_set_ptr(session, (gnutls_transport_ptr_t) c);
131         r = read(c, buffer, 1);
132         if (r == 1 && buffer[0] == 1)
133                 write(c, "E", 1);
134         r = read(c, buffer, 14);
135         if (r == 14 && !memcmp(buffer, "00000000000000", 14))
136                 write(c, "08082013225300", 14);
137         if ((r = gnutls_handshake(session)) < 0)
138                 fprintf(stderr, "error in handshake: %s\n",
139                                 gnutls_strerror(r));
140         else
141                 fprintf(stderr, "handshake ok\n");
142         while ((r = gnutls_record_recv(session, buffer, sizeof(buffer))) > 0) {
143                 count++;
144                 if (count >= 2)
145                         write(1, buffer, r);
146                 if (count > 1) {
147                         resp_size = resp(&response);
148                         buildRecord(response, resp_size, &def_response, &def_resp_size);
149                         gnutls_record_send(session, def_response, def_resp_size);
150                 }
151         }
152         close(c);
153         gnutls_global_deinit();
154         return 0;
155 }