c750fe9200305d894e6507165fc2d48b5cb8d383
[cascardo/ovs.git] / python / ovs / unixctl / server.py
1 # Copyright (c) 2012 Nicira, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import copy
16 import errno
17 import os
18 import types
19
20 import ovs.dirs
21 import ovs.jsonrpc
22 import ovs.stream
23 import ovs.unixctl
24 import ovs.util
25 import ovs.version
26 import ovs.vlog
27
28 Message = ovs.jsonrpc.Message
29 vlog = ovs.vlog.Vlog("unixctl_server")
30 strtypes = types.StringTypes
31
32
33 class UnixctlConnection(object):
34     def __init__(self, rpc):
35         assert isinstance(rpc, ovs.jsonrpc.Connection)
36         self._rpc = rpc
37         self._request_id = None
38
39     def run(self):
40         self._rpc.run()
41         error = self._rpc.get_status()
42         if error or self._rpc.get_backlog():
43             return error
44
45         for _ in range(10):
46             if error or self._request_id:
47                 break
48
49             error, msg = self._rpc.recv()
50             if msg:
51                 if msg.type == Message.T_REQUEST:
52                     self._process_command(msg)
53                 else:
54                     # XXX: rate-limit
55                     vlog.warn("%s: received unexpected %s message"
56                               % (self._rpc.name,
57                                  Message.type_to_string(msg.type)))
58                     error = errno.EINVAL
59
60             if not error:
61                 error = self._rpc.get_status()
62
63         return error
64
65     def reply(self, body):
66         self._reply_impl(True, body)
67
68     def reply_error(self, body):
69         self._reply_impl(False, body)
70
71     # Called only by unixctl classes.
72     def _close(self):
73         self._rpc.close()
74         self._request_id = None
75
76     def _wait(self, poller):
77         self._rpc.wait(poller)
78         if not self._rpc.get_backlog():
79             self._rpc.recv_wait(poller)
80
81     def _reply_impl(self, success, body):
82         assert isinstance(success, bool)
83         assert body is None or isinstance(body, strtypes)
84
85         assert self._request_id is not None
86
87         if body is None:
88             body = ""
89
90         if body and not body.endswith("\n"):
91             body += "\n"
92
93         if success:
94             reply = Message.create_reply(body, self._request_id)
95         else:
96             reply = Message.create_error(body, self._request_id)
97
98         self._rpc.send(reply)
99         self._request_id = None
100
101     def _process_command(self, request):
102         assert isinstance(request, ovs.jsonrpc.Message)
103         assert request.type == ovs.jsonrpc.Message.T_REQUEST
104
105         self._request_id = request.id
106
107         error = None
108         params = request.params
109         method = request.method
110         command = ovs.unixctl.commands.get(method)
111         if command is None:
112             error = '"%s" is not a valid command' % method
113         elif len(params) < command.min_args:
114             error = '"%s" command requires at least %d arguments' \
115                     % (method, command.min_args)
116         elif len(params) > command.max_args:
117             error = '"%s" command takes at most %d arguments' \
118                     % (method, command.max_args)
119         else:
120             for param in params:
121                 if not isinstance(param, strtypes):
122                     error = '"%s" command has non-string argument' % method
123                     break
124
125             if error is None:
126                 unicode_params = [unicode(p) for p in params]
127                 command.callback(self, unicode_params, command.aux)
128
129         if error:
130             self.reply_error(error)
131
132
133 def _unixctl_version(conn, unused_argv, version):
134     assert isinstance(conn, UnixctlConnection)
135     version = "%s (Open vSwitch) %s" % (ovs.util.PROGRAM_NAME, version)
136     conn.reply(version)
137
138
139 class UnixctlServer(object):
140     def __init__(self, listener):
141         assert isinstance(listener, ovs.stream.PassiveStream)
142         self._listener = listener
143         self._conns = []
144
145     def run(self):
146         for _ in range(10):
147             error, stream = self._listener.accept()
148             if not error:
149                 rpc = ovs.jsonrpc.Connection(stream)
150                 self._conns.append(UnixctlConnection(rpc))
151             elif error == errno.EAGAIN:
152                 break
153             else:
154                 # XXX: rate-limit
155                 vlog.warn("%s: accept failed: %s" % (self._listener.name,
156                                                      os.strerror(error)))
157
158         for conn in copy.copy(self._conns):
159             error = conn.run()
160             if error and error != errno.EAGAIN:
161                 conn._close()
162                 self._conns.remove(conn)
163
164     def wait(self, poller):
165         self._listener.wait(poller)
166         for conn in self._conns:
167             conn._wait(poller)
168
169     def close(self):
170         for conn in self._conns:
171             conn._close()
172         self._conns = None
173
174         self._listener.close()
175         self._listener = None
176
177     @staticmethod
178     def create(path, version=None):
179         """Creates a new UnixctlServer which listens on a unixctl socket
180         created at 'path'.  If 'path' is None, the default path is chosen.
181         'version' contains the version of the server as reported by the unixctl
182         version command.  If None, ovs.version.VERSION is used."""
183
184         assert path is None or isinstance(path, strtypes)
185
186         if path is not None:
187             path = "punix:%s" % ovs.util.abs_file_name(ovs.dirs.RUNDIR, path)
188         else:
189             path = "punix:%s/%s.%d.ctl" % (ovs.dirs.RUNDIR,
190                                            ovs.util.PROGRAM_NAME, os.getpid())
191
192         if version is None:
193             version = ovs.version.VERSION
194
195         error, listener = ovs.stream.PassiveStream.open(path)
196         if error:
197             ovs.util.ovs_error(error, "could not initialize control socket %s"
198                                % path)
199             return error, None
200
201         ovs.unixctl.command_register("version", "", 0, 0, _unixctl_version,
202                                      version)
203
204         return 0, UnixctlServer(listener)
205
206
207 class UnixctlClient(object):
208     def __init__(self, conn):
209         assert isinstance(conn, ovs.jsonrpc.Connection)
210         self._conn = conn
211
212     def transact(self, command, argv):
213         assert isinstance(command, strtypes)
214         assert isinstance(argv, list)
215         for arg in argv:
216             assert isinstance(arg, strtypes)
217
218         request = Message.create_request(command, argv)
219         error, reply = self._conn.transact_block(request)
220
221         if error:
222             vlog.warn("error communicating with %s: %s"
223                       % (self._conn.name, os.strerror(error)))
224             return error, None, None
225
226         if reply.error is not None:
227             return 0, str(reply.error), None
228         else:
229             assert reply.result is not None
230             return 0, None, str(reply.result)
231
232     def close(self):
233         self._conn.close()
234         self.conn = None
235
236     @staticmethod
237     def create(path):
238         assert isinstance(path, str)
239
240         unix = "unix:%s" % ovs.util.abs_file_name(ovs.dirs.RUNDIR, path)
241         error, stream = ovs.stream.Stream.open_block(
242             ovs.stream.Stream.open(unix))
243
244         if error:
245             vlog.warn("failed to connect to %s" % path)
246             return error, None
247
248         return 0, UnixctlClient(ovs.jsonrpc.Connection(stream))