Make unixctl_command_register() idempotent
[cascardo/ovs.git] / lib / unixctl.c
1 /*
2  * Copyright (c) 2008, 2009, 2010, 2011 Nicira Networks.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at:
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <config.h>
18 #include "unixctl.h"
19 #include <assert.h>
20 #include <ctype.h>
21 #include <errno.h>
22 #include <poll.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #include <sys/socket.h>
26 #include <sys/stat.h>
27 #include <unistd.h>
28 #include "coverage.h"
29 #include "dirs.h"
30 #include "dynamic-string.h"
31 #include "fatal-signal.h"
32 #include "list.h"
33 #include "ofpbuf.h"
34 #include "poll-loop.h"
35 #include "shash.h"
36 #include "socket-util.h"
37 #include "svec.h"
38 #include "util.h"
39 #include "vlog.h"
40
41 #ifndef SCM_CREDENTIALS
42 #include <time.h>
43 #endif
44
45 VLOG_DEFINE_THIS_MODULE(unixctl);
46
47 COVERAGE_DEFINE(unixctl_received);
48 COVERAGE_DEFINE(unixctl_replied);
49 \f
50 struct unixctl_command {
51     const char *args;
52     unixctl_cb_func *cb;
53     void *aux;
54 };
55
56 struct unixctl_conn {
57     struct list node;
58     int fd;
59
60     enum { S_RECV, S_PROCESS, S_SEND } state;
61     struct ofpbuf in;
62     struct ds out;
63     size_t out_pos;
64 };
65
66 /* Server for control connection. */
67 struct unixctl_server {
68     char *path;
69     int fd;
70     struct list conns;
71 };
72
73 /* Client for control connection. */
74 struct unixctl_client {
75     char *connect_path;
76     char *bind_path;
77     FILE *stream;
78 };
79
80 static struct vlog_rate_limit rl = VLOG_RATE_LIMIT_INIT(5, 5);
81
82 static struct shash commands = SHASH_INITIALIZER(&commands);
83
84 static void
85 unixctl_help(struct unixctl_conn *conn, const char *args OVS_UNUSED,
86              void *aux OVS_UNUSED)
87 {
88     struct ds ds = DS_EMPTY_INITIALIZER;
89     const struct shash_node **nodes = shash_sort(&commands);
90     size_t i;
91
92     ds_put_cstr(&ds, "The available commands are:\n");
93
94     for (i = 0; i < shash_count(&commands); i++) {
95         const struct shash_node *node = nodes[i];
96         const struct unixctl_command *command = node->data;
97         
98         ds_put_format(&ds, "  %-23s%s\n", node->name, command->args);
99     }
100     free(nodes);
101
102     unixctl_command_reply(conn, 214, ds_cstr(&ds));
103     ds_destroy(&ds);
104 }
105
106 static void
107 unixctl_version(struct unixctl_conn *conn, const char *args OVS_UNUSED,
108                 void *aux OVS_UNUSED)
109 {
110     unixctl_command_reply(conn, 200, get_program_version());
111 }
112
113 void
114 unixctl_command_register(const char *name, const char *args,
115         unixctl_cb_func *cb, void *aux)
116 {
117     struct unixctl_command *command;
118     struct unixctl_command *lookup = shash_find_data(&commands, name);
119
120     assert(!lookup || lookup->cb == cb);
121
122     if (lookup) {
123         return;
124     }
125
126     command = xmalloc(sizeof *command);
127     command->args = args;
128     command->cb = cb;
129     command->aux = aux;
130     shash_add(&commands, name, command);
131 }
132
133 static const char *
134 translate_reply_code(int code)
135 {
136     switch (code) {
137     case 200: return "OK";
138     case 201: return "Created";
139     case 202: return "Accepted";
140     case 204: return "No Content";
141     case 211: return "System Status";
142     case 214: return "Help";
143     case 400: return "Bad Request";
144     case 401: return "Unauthorized";
145     case 403: return "Forbidden";
146     case 404: return "Not Found";
147     case 500: return "Internal Server Error";
148     case 501: return "Invalid Argument";
149     case 503: return "Service Unavailable";
150     default: return "Unknown";
151     }
152 }
153
154 void
155 unixctl_command_reply(struct unixctl_conn *conn,
156                       int code, const char *body)
157 {
158     struct ds *out = &conn->out;
159
160     COVERAGE_INC(unixctl_replied);
161     assert(conn->state == S_PROCESS);
162     conn->state = S_SEND;
163     conn->out_pos = 0;
164
165     ds_clear(out);
166     ds_put_format(out, "%03d %s\n", code, translate_reply_code(code));
167     if (body) {
168         const char *p;
169         for (p = body; *p != '\0'; ) {
170             size_t n = strcspn(p, "\n");
171
172             if (*p == '.') {
173                 ds_put_char(out, '.');
174             }
175             ds_put_buffer(out, p, n);
176             ds_put_char(out, '\n');
177             p += n;
178             if (*p == '\n') {
179                 p++;
180             }
181         }
182     }
183     ds_put_cstr(out, ".\n");
184 }
185
186 /* Creates a unixctl server listening on 'path', which may be:
187  *
188  *      - NULL, in which case <rundir>/<program>.<pid>.ctl is used.
189  *
190  *      - "none", in which case the function will return successfully but
191  *        no socket will actually be created.
192  *
193  *      - A name that does not start with '/', in which case it is put in
194  *        <rundir>.
195  *
196  *      - An absolute path (starting with '/') that gives the exact name of
197  *        the Unix domain socket to listen on.
198  *
199  * A program that (optionally) daemonizes itself should call this function
200  * *after* daemonization, so that the socket name contains the pid of the
201  * daemon instead of the pid of the program that exited.  (Otherwise,
202  * "ovs-appctl --target=<program>" will fail.)
203  *
204  * Returns 0 if successful, otherwise a positive errno value.  If successful,
205  * sets '*serverp' to the new unixctl_server (or to NULL if 'path' was "none"),
206  * otherwise to NULL. */
207 int
208 unixctl_server_create(const char *path, struct unixctl_server **serverp)
209 {
210     struct unixctl_server *server;
211     int error;
212
213     if (path && !strcmp(path, "none")) {
214         *serverp = NULL;
215         return 0;
216     }
217
218     unixctl_command_register("help", "", unixctl_help, NULL);
219     unixctl_command_register("version", "", unixctl_version, NULL);
220
221     server = xmalloc(sizeof *server);
222     list_init(&server->conns);
223
224     if (path) {
225         server->path = abs_file_name(ovs_rundir(), path);
226     } else {
227         server->path = xasprintf("%s/%s.%ld.ctl", ovs_rundir(),
228                                  program_name, (long int) getpid());
229     }
230
231     server->fd = make_unix_socket(SOCK_STREAM, true, false, server->path,
232                                   NULL);
233     if (server->fd < 0) {
234         error = -server->fd;
235         ovs_error(error, "could not initialize control socket %s",
236                   server->path);
237         goto error;
238     }
239
240     if (chmod(server->path, S_IRUSR | S_IWUSR) < 0) {
241         error = errno;
242         ovs_error(error, "failed to chmod control socket %s", server->path);
243         goto error;
244     }
245
246     if (listen(server->fd, 10) < 0) {
247         error = errno;
248         ovs_error(error, "Failed to listen on control socket %s",
249                   server->path);
250         goto error;
251     }
252
253     *serverp = server;
254     return 0;
255
256 error:
257     if (server->fd >= 0) {
258         close(server->fd);
259     }
260     free(server->path);
261     free(server);
262     *serverp = NULL;
263     return error;
264 }
265
266 static void
267 new_connection(struct unixctl_server *server, int fd)
268 {
269     struct unixctl_conn *conn;
270
271     set_nonblocking(fd);
272
273     conn = xmalloc(sizeof *conn);
274     list_push_back(&server->conns, &conn->node);
275     conn->fd = fd;
276     conn->state = S_RECV;
277     ofpbuf_init(&conn->in, 128);
278     ds_init(&conn->out);
279     conn->out_pos = 0;
280 }
281
282 static int
283 run_connection_output(struct unixctl_conn *conn)
284 {
285     while (conn->out_pos < conn->out.length) {
286         size_t bytes_written;
287         int error;
288
289         error = write_fully(conn->fd, conn->out.string + conn->out_pos,
290                             conn->out.length - conn->out_pos, &bytes_written);
291         conn->out_pos += bytes_written;
292         if (error) {
293             return error;
294         }
295     }
296     conn->state = S_RECV;
297     return 0;
298 }
299
300 static void
301 process_command(struct unixctl_conn *conn, char *s)
302 {
303     struct unixctl_command *command;
304     size_t name_len;
305     char *name, *args;
306
307     COVERAGE_INC(unixctl_received);
308     conn->state = S_PROCESS;
309
310     name = s;
311     name_len = strcspn(name, " ");
312     args = name + name_len;
313     args += strspn(args, " ");
314     name[name_len] = '\0';
315
316     command = shash_find_data(&commands, name);
317     if (command) {
318         command->cb(conn, args, command->aux);
319     } else {
320         char *msg = xasprintf("\"%s\" is not a valid command", name);
321         unixctl_command_reply(conn, 400, msg);
322         free(msg);
323     }
324 }
325
326 static int
327 run_connection_input(struct unixctl_conn *conn)
328 {
329     for (;;) {
330         size_t bytes_read;
331         char *newline;
332         int error;
333
334         newline = memchr(conn->in.data, '\n', conn->in.size);
335         if (newline) {
336             char *command = conn->in.data;
337             size_t n = newline - command + 1;
338
339             if (n > 0 && newline[-1] == '\r') {
340                 newline--;
341             }
342             *newline = '\0';
343
344             process_command(conn, command);
345
346             ofpbuf_pull(&conn->in, n);
347             if (!conn->in.size) {
348                 ofpbuf_clear(&conn->in);
349             }
350             return 0;
351         }
352
353         ofpbuf_prealloc_tailroom(&conn->in, 128);
354         error = read_fully(conn->fd, ofpbuf_tail(&conn->in),
355                            ofpbuf_tailroom(&conn->in), &bytes_read);
356         conn->in.size += bytes_read;
357         if (conn->in.size > 65536) {
358             VLOG_WARN_RL(&rl, "excess command length, killing connection");
359             return EPROTO;
360         }
361         if (error) {
362             if (error == EAGAIN || error == EWOULDBLOCK) {
363                 if (!bytes_read) {
364                     return EAGAIN;
365                 }
366             } else {
367                 if (error != EOF || conn->in.size != 0) {
368                     VLOG_WARN_RL(&rl, "read failed: %s",
369                                  (error == EOF
370                                   ? "connection dropped mid-command"
371                                   : strerror(error)));
372                 }
373                 return error;
374             }
375         }
376     }
377 }
378
379 static int
380 run_connection(struct unixctl_conn *conn)
381 {
382     int old_state;
383     do {
384         int error;
385
386         old_state = conn->state;
387         switch (conn->state) {
388         case S_RECV:
389             error = run_connection_input(conn);
390             break;
391
392         case S_PROCESS:
393             error = 0;
394             break;
395
396         case S_SEND:
397             error = run_connection_output(conn);
398             break;
399
400         default:
401             NOT_REACHED();
402         }
403         if (error) {
404             return error;
405         }
406     } while (conn->state != old_state);
407     return 0;
408 }
409
410 static void
411 kill_connection(struct unixctl_conn *conn)
412 {
413     list_remove(&conn->node);
414     ofpbuf_uninit(&conn->in);
415     ds_destroy(&conn->out);
416     close(conn->fd);
417     free(conn);
418 }
419
420 void
421 unixctl_server_run(struct unixctl_server *server)
422 {
423     struct unixctl_conn *conn, *next;
424     int i;
425
426     if (!server) {
427         return;
428     }
429
430     for (i = 0; i < 10; i++) {
431         int fd = accept(server->fd, NULL, NULL);
432         if (fd < 0) {
433             if (errno != EAGAIN && errno != EWOULDBLOCK) {
434                 VLOG_WARN_RL(&rl, "accept failed: %s", strerror(errno));
435             }
436             break;
437         }
438         new_connection(server, fd);
439     }
440
441     LIST_FOR_EACH_SAFE (conn, next, node, &server->conns) {
442         int error = run_connection(conn);
443         if (error && error != EAGAIN) {
444             kill_connection(conn);
445         }
446     }
447 }
448
449 void
450 unixctl_server_wait(struct unixctl_server *server)
451 {
452     struct unixctl_conn *conn;
453
454     if (!server) {
455         return;
456     }
457
458     poll_fd_wait(server->fd, POLLIN);
459     LIST_FOR_EACH (conn, node, &server->conns) {
460         if (conn->state == S_RECV) {
461             poll_fd_wait(conn->fd, POLLIN);
462         } else if (conn->state == S_SEND) {
463             poll_fd_wait(conn->fd, POLLOUT);
464         }
465     }
466 }
467
468 /* Destroys 'server' and stops listening for connections. */
469 void
470 unixctl_server_destroy(struct unixctl_server *server)
471 {
472     if (server) {
473         struct unixctl_conn *conn, *next;
474
475         LIST_FOR_EACH_SAFE (conn, next, node, &server->conns) {
476             kill_connection(conn);
477         }
478
479         close(server->fd);
480         fatal_signal_unlink_file_now(server->path);
481         free(server->path);
482         free(server);
483     }
484 }
485 \f
486 /* Connects to a Vlog server socket.  'path' should be the name of a Vlog
487  * server socket.  If it does not start with '/', it will be prefixed with
488  * the rundir (e.g. /usr/local/var/run/openvswitch).
489  *
490  * Returns 0 if successful, otherwise a positive errno value.  If successful,
491  * sets '*clientp' to the new unixctl_client, otherwise to NULL. */
492 int
493 unixctl_client_create(const char *path, struct unixctl_client **clientp)
494 {
495     static int counter;
496     struct unixctl_client *client;
497     int error;
498     int fd = -1;
499
500     /* Determine location. */
501     client = xmalloc(sizeof *client);
502     client->connect_path = abs_file_name(ovs_rundir(), path);
503     client->bind_path = xasprintf("/tmp/vlog.%ld.%d",
504                                   (long int) getpid(), counter++);
505
506     /* Open socket. */
507     fd = make_unix_socket(SOCK_STREAM, false, false,
508                           client->bind_path, client->connect_path);
509     if (fd < 0) {
510         error = -fd;
511         goto error;
512     }
513
514     /* Bind socket to stream. */
515     client->stream = fdopen(fd, "r+");
516     if (!client->stream) {
517         error = errno;
518         VLOG_WARN("%s: fdopen failed (%s)",
519                   client->connect_path, strerror(error));
520         goto error;
521     }
522     *clientp = client;
523     return 0;
524
525 error:
526     if (fd >= 0) {
527         close(fd);
528     }
529     free(client->connect_path);
530     free(client->bind_path);
531     free(client);
532     *clientp = NULL;
533     return error;
534 }
535
536 /* Destroys 'client'. */
537 void
538 unixctl_client_destroy(struct unixctl_client *client)
539 {
540     if (client) {
541         fatal_signal_unlink_file_now(client->bind_path);
542         free(client->bind_path);
543         free(client->connect_path);
544         fclose(client->stream);
545         free(client);
546     }
547 }
548
549 /* Sends 'request' to the server socket and waits for a reply.  Returns 0 if
550  * successful, otherwise to a positive errno value.  If successful, sets
551  * '*reply' to the reply, which the caller must free, otherwise to NULL. */
552 int
553 unixctl_client_transact(struct unixctl_client *client,
554                         const char *request,
555                         int *reply_code, char **reply_body)
556 {
557     struct ds line = DS_EMPTY_INITIALIZER;
558     struct ds reply = DS_EMPTY_INITIALIZER;
559     int error;
560
561     /* Send 'request' to server.  Add a new-line if 'request' didn't end in
562      * one. */
563     fputs(request, client->stream);
564     if (request[0] == '\0' || request[strlen(request) - 1] != '\n') {
565         putc('\n', client->stream);
566     }
567     if (ferror(client->stream)) {
568         VLOG_WARN("error sending request to %s: %s",
569                   client->connect_path, strerror(errno));
570         return errno;
571     }
572
573     /* Wait for response. */
574     *reply_code = -1;
575     for (;;) {
576         const char *s;
577
578         error = ds_get_line(&line, client->stream);
579         if (error) {
580             VLOG_WARN("error reading reply from %s: %s",
581                       client->connect_path,
582                       ovs_retval_to_string(error));
583             goto error;
584         }
585
586         s = ds_cstr(&line);
587         if (*reply_code == -1) {
588             if (!isdigit((unsigned char)s[0])
589                     || !isdigit((unsigned char)s[1])
590                     || !isdigit((unsigned char)s[2])) {
591                 VLOG_WARN("reply from %s does not start with 3-digit code",
592                           client->connect_path);
593                 error = EPROTO;
594                 goto error;
595             }
596             sscanf(s, "%3d", reply_code);
597         } else {
598             if (s[0] == '.') {
599                 if (s[1] == '\0') {
600                     break;
601                 }
602                 s++;
603             }
604             ds_put_cstr(&reply, s);
605             ds_put_char(&reply, '\n');
606         }
607     }
608     *reply_body = ds_cstr(&reply);
609     ds_destroy(&line);
610     return 0;
611
612 error:
613     ds_destroy(&line);
614     ds_destroy(&reply);
615     *reply_code = 0;
616     *reply_body = NULL;
617     return error == EOF ? EPROTO : error;
618 }
619
620 /* Returns the path of the server socket to which 'client' is connected.  The
621  * caller must not modify or free the returned string. */
622 const char *
623 unixctl_client_target(const struct unixctl_client *client)
624 {
625     return client->connect_path;
626 }