python: Drop unicode type.
[cascardo/ovs.git] / python / ovs / jsonrpc.py
1 # Copyright (c) 2010, 2011, 2012, 2013 Nicira, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import errno
16 import os
17
18 import six
19
20 import ovs.json
21 import ovs.poller
22 import ovs.reconnect
23 import ovs.stream
24 import ovs.timeval
25 import ovs.util
26 import ovs.vlog
27
# Alias for ovs.util.EOF.  This module uses it as the status value for a
# connection whose peer closed the stream (Connection.recv) and for a
# deliberate disconnect (Session.__disconnect).
EOF = ovs.util.EOF
# Logger for this module, created under the name "jsonrpc".
vlog = ovs.vlog.Vlog("jsonrpc")
30
31
class Message(object):
    """One JSON-RPC message: a request, notification, reply or error reply.

    'type' holds one of the T_* constants below; 'method', 'params',
    'result', 'error' and 'id' hold the JSON-RPC members of the same name
    (None when not set)."""

    T_REQUEST = 0               # Request.
    T_NOTIFY = 1                # Notification.
    T_REPLY = 2                 # Successful reply.
    T_ERROR = 3                 # Error reply.

    __types = {T_REQUEST: "request",
               T_NOTIFY: "notification",
               T_REPLY: "reply",
               T_ERROR: "error"}

    # For each message type: whether each member, in the order
    # (method, params, result, error, id), must be present (True) or must be
    # absent (False).
    __required_members = {
        T_REQUEST: (True, True, False, False, True),
        T_NOTIFY: (True, True, False, False, False),
        T_REPLY: (False, False, True, False, True),
        T_ERROR: (False, False, False, True, True),
    }

    # Next id handed out by _create_id(); shared by all requests.
    _next_id = 0

    def __init__(self, type_, method, params, result, error, id):
        self.type = type_
        self.method = method
        self.params = params
        self.result = result
        self.error = error
        self.id = id

    @staticmethod
    def _create_id():
        """Returns a fresh request id from a process-wide counter."""
        allocated = Message._next_id
        Message._next_id = allocated + 1
        return allocated

    @staticmethod
    def create_request(method, params):
        """Returns a request for 'method' with a freshly allocated id."""
        return Message(type_=Message.T_REQUEST, method=method, params=params,
                       result=None, error=None, id=Message._create_id())

    @staticmethod
    def create_notify(method, params):
        """Returns a notification (a request that carries no id)."""
        return Message(type_=Message.T_NOTIFY, method=method, params=params,
                       result=None, error=None, id=None)

    @staticmethod
    def create_reply(result, id):
        """Returns a successful reply carrying 'result' for request 'id'."""
        return Message(type_=Message.T_REPLY, method=None, params=None,
                       result=result, error=None, id=id)

    @staticmethod
    def create_error(error, id):
        """Returns an error reply carrying 'error' for request 'id'."""
        return Message(type_=Message.T_ERROR, method=None, params=None,
                       result=None, error=error, id=id)

    @staticmethod
    def type_to_string(type_):
        """Maps a T_* constant to its name; raises KeyError otherwise."""
        return Message.__types[type_]

    def __validate_arg(self, value, name, must_have):
        """Returns None if the presence of 'value' matches 'must_have',
        otherwise a description of the mismatch naming member 'name'."""
        present = value is not None
        required = must_have != 0
        if present == required:
            return None
        type_name = Message.type_to_string(self.type)
        verb = "must" if must_have else "must not"
        return "%s %s have \"%s\"" % (type_name, verb, name)

    def is_valid(self):
        """Returns None if this message is well formed, otherwise a string
        describing the first problem found."""
        if self.params is not None and type(self.params) is not list:
            return "\"params\" must be JSON array"

        rules = Message.__required_members.get(self.type)
        if rules is None:
            return "invalid JSON-RPC message type %s" % self.type

        members = (("method", self.method),
                   ("params", self.params),
                   ("result", self.result),
                   ("error", self.error),
                   ("id", self.id))
        for (name, value), required in zip(members, rules):
            problem = self.__validate_arg(value, name, required)
            if problem:
                return problem
        return None

    @staticmethod
    def from_json(json):
        """Builds a Message from the decoded JSON object 'json'.

        Returns the Message on success or an error string on failure; the
        caller's dict is never modified."""
        if type(json) is not dict:
            return "message is not a JSON object"

        # Pop members from a private copy, so whatever is left over is
        # exactly the set of unexpected members.
        members = dict(json)

        method = None
        if "method" in members:
            method = members.pop("method")
            if not isinstance(method, six.string_types):
                return "method is not a JSON string"

        params, result, error, id_ = [members.pop(key, None) for key in
                                      ("params", "result", "error", "id")]
        if members:
            extra = members.popitem()[0]
            return "message has unexpected member \"%s\"" % extra

        # The message type is inferred from which members are present.
        msg_type = (Message.T_REPLY if result is not None else
                    Message.T_ERROR if error is not None else
                    Message.T_REQUEST if id_ is not None else
                    Message.T_NOTIFY)

        msg = Message(msg_type, method, params, result, error, id_)
        problem = msg.is_valid()
        return msg if problem is None else problem

    def to_json(self):
        """Returns this message as a dict ready for JSON serialization.

        Replies always carry "error", error replies always carry "result",
        and notifications always carry "id", even when the value is None."""
        # (member name, value, emit even when the value is None)
        candidates = (
            ("method", self.method, False),
            ("params", self.params, False),
            ("result", self.result, self.type == Message.T_ERROR),
            ("error", self.error, self.type == Message.T_REPLY),
            ("id", self.id, self.type == Message.T_NOTIFY),
        )
        return {name: value for name, value, always in candidates
                if value is not None or always}

    def __str__(self):
        pieces = [Message.type_to_string(self.type)]
        if self.method is not None:
            pieces.append("method=\"%s\"" % self.method)
        for prefix, value in (("params=", self.params),
                              ("result=", self.result),
                              ("error=", self.error),
                              ("id=", self.id)):
            if value is not None:
                pieces.append(prefix + ovs.json.to_string(value))
        return ", ".join(pieces)
181
182
class Connection(object):
    """A JSON-RPC connection layered on a single stream, with no reconnection.

    Outgoing messages are serialized into 'output' and drained by run();
    incoming data accumulates in 'input' and is fed to an incremental
    ovs.json.Parser by recv().  Once 'status' becomes nonzero (an errno value
    or EOF) the connection is dead, and most methods just report that status.
    """

    def __init__(self, stream):
        self.name = stream.name
        self.stream = stream
        self.status = 0             # 0 while healthy, else the first error.
        self.input = ""             # Received data not yet parsed.
        self.output = ""            # Serialized data not yet sent.
        self.parser = None          # In-progress ovs.json.Parser, if any.
        self.received_bytes = 0     # Total length of data read so far.

    def close(self):
        """Closes the underlying stream and drops the reference to it."""
        self.stream.close()
        self.stream = None

    def run(self):
        """Sends as much queued output as the stream accepts right now.

        A send failure other than EAGAIN is logged and recorded via error().
        """
        if self.status:
            return

        while self.output:
            sent = self.stream.send(self.output)
            if sent < 0:
                # Negative results are negated errno codes.
                if sent != -errno.EAGAIN:
                    vlog.warn("%s: send error: %s" %
                              (self.name, os.strerror(-sent)))
                    self.error(-sent)
                break
            self.output = self.output[sent:]

    def wait(self, poller):
        """Registers with 'poller' the events run() needs to make progress."""
        if self.status:
            return
        self.stream.run_wait(poller)
        if self.output:
            self.stream.send_wait(poller)

    def get_status(self):
        """Returns 0 if the connection is healthy, else its error/EOF value."""
        return self.status

    def get_backlog(self):
        """Returns the length of queued, unsent output (0 once failed)."""
        if self.status == 0:
            return len(self.output)
        return 0

    def get_received_bytes(self):
        """Returns the total length of data received from the stream."""
        return self.received_bytes

    def __log_msg(self, title, msg):
        if not vlog.dbg_is_enabled():
            return
        vlog.dbg("%s: %s %s" % (self.name, title, msg))

    def send(self, msg):
        """Queues Message 'msg' for sending and returns the current status.

        Sending starts immediately only if nothing was already queued."""
        if self.status:
            return self.status

        self.__log_msg("send", msg)

        queue_was_empty = not self.output
        self.output += ovs.json.to_string(msg.to_json())
        if queue_was_empty:
            self.run()
        return self.status

    def send_block(self, msg):
        """Sends 'msg' and blocks until all queued output is flushed or the
        connection fails; returns the resulting status."""
        error = self.send(msg)
        if error:
            return error

        self.run()
        while self.get_backlog() and not self.get_status():
            poller = ovs.poller.Poller()
            self.wait(poller)
            poller.block()
            self.run()
        return self.status

    def recv(self):
        """Attempts to receive one Message without blocking.

        Returns (0, msg) on success, (errno.EAGAIN, None) if no complete
        message is available yet, or (status, None) on error/EOF."""
        if self.status:
            return self.status, None

        while True:
            if self.input:
                if self.parser is None:
                    self.parser = ovs.json.Parser()
                consumed = self.parser.feed(self.input)
                self.input = self.input[consumed:]
                if self.parser.is_done():
                    msg = self.__process_msg()
                    if msg:
                        return 0, msg
                    return self.status, None
                continue

            error, data = self.stream.recv(4096)
            if error == errno.EAGAIN:
                return error, None
            if error:
                # XXX rate-limit
                vlog.warn("%s: receive error: %s"
                          % (self.name, os.strerror(error)))
                self.error(error)
                return self.status, None
            if not data:
                # The peer closed the stream.
                self.error(EOF)
                return EOF, None
            self.input += data
            self.received_bytes += len(data)

    def recv_block(self):
        """Like recv(), but blocks until the result is something other than
        EAGAIN; keeps flushing output while it waits."""
        error, msg = self.recv()
        while error == errno.EAGAIN:
            self.run()

            poller = ovs.poller.Poller()
            self.wait(poller)
            self.recv_wait(poller)
            poller.block()

            error, msg = self.recv()
        return error, msg

    def transact_block(self, request):
        """Sends 'request' and blocks until a reply or error reply with the
        same id arrives, discarding other messages.

        Returns (error, reply)."""
        wanted_id = request.id

        error = self.send(request)
        reply = None
        while not error:
            error, reply = self.recv_block()
            if not reply:
                continue
            is_response = reply.type in (Message.T_REPLY, Message.T_ERROR)
            if is_response and reply.id == wanted_id:
                break
        return error, reply

    def __process_msg(self):
        """Finishes the current parse and converts it into a Message.

        Returns the Message, or None after recording EPROTO via error()."""
        json = self.parser.finish()
        self.parser = None

        # A string result from the parser is a parse error description.
        if isinstance(json, six.string_types):
            # XXX rate-limit
            vlog.warn("%s: error parsing stream: %s" % (self.name, json))
            self.error(errno.EPROTO)
            return None

        msg = Message.from_json(json)
        if not isinstance(msg, Message):
            # XXX rate-limit
            vlog.warn("%s: received bad JSON-RPC message: %s"
                      % (self.name, msg))
            self.error(errno.EPROTO)
            return None

        self.__log_msg("received", msg)
        return msg

    def recv_wait(self, poller):
        """Arranges for 'poller' to wake when recv() may make progress."""
        if not self.status and not self.input:
            self.stream.recv_wait(poller)
        else:
            # Either recv() will return an error immediately, or there is
            # buffered input it can parse without waiting.
            poller.immediate_wake()

    def error(self, error):
        """Marks the connection failed with 'error'.

        Only the first error is kept; it closes the stream and discards any
        unsent output."""
        if self.status != 0:
            return
        self.status = error
        self.stream.close()
        self.output = ""
352
class Session(object):
    """A JSON-RPC session with reconnection.

    Wraps a Connection ('rpc') and uses an ovs.reconnect.Reconnect FSM to
    decide when to (re)connect, disconnect or send "echo" probes.  In active
    mode it opens an ovs.stream.Stream to the configured name; in passive
    mode it listens on an ovs.stream.PassiveStream and adopts each accepted
    connection."""

    def __init__(self, reconnect, rpc):
        """Creates a session driven by 'reconnect', optionally starting out
        with the already-established Connection 'rpc' (or None).

        Callers normally use Session.open() or Session.open_unreliably()."""
        self.reconnect = reconnect
        self.rpc = rpc          # Established Connection, or None.
        self.stream = None      # Active-mode stream still connecting, or None.
        self.pstream = None     # Passive-mode listening stream, or None.
        # Incremented whenever a connection attempt starts or an existing
        # connection (or pending stream) is dropped; see get_seqno().
        self.seqno = 0

    @staticmethod
    def open(name):
        """Creates and returns a Session that maintains a JSON-RPC session to
        'name', which should be a string acceptable to ovs.stream.Stream or
        ovs.stream.PassiveStream's initializer.

        If 'name' is an active connection method, e.g. "tcp:127.1.2.3", the new
        session connects and reconnects, with back-off, to 'name'.

        If 'name' is a passive connection method, e.g. "ptcp:", the new session
        listens for connections to 'name'.  It maintains at most one connection
        at any given time.  Any new connection causes the previous one (if any)
        to be dropped."""
        reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec())
        reconnect.set_name(name)
        reconnect.enable(ovs.timeval.msec())

        if ovs.stream.PassiveStream.is_valid_name(name):
            reconnect.set_passive(True, ovs.timeval.msec())

        if not ovs.stream.stream_or_pstream_needs_probes(name):
            reconnect.set_probe_interval(0)

        return Session(reconnect, None)

    @staticmethod
    def open_unreliably(jsonrpc):
        """Creates and returns a Session wrapping the already-connected
        Connection 'jsonrpc'.

        The reconnect FSM is made quiet, marked connected immediately, and
        given a maximum of 0 connection attempts, so the session is not
        expected to reconnect once 'jsonrpc' is dropped."""
        reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec())
        reconnect.set_quiet(True)
        reconnect.set_name(jsonrpc.name)
        reconnect.set_max_tries(0)
        reconnect.connected(ovs.timeval.msec())
        return Session(reconnect, jsonrpc)

    def close(self):
        """Closes the connection, any pending stream and any listener."""
        if self.rpc is not None:
            self.rpc.close()
            self.rpc = None
        if self.stream is not None:
            self.stream.close()
            self.stream = None
        if self.pstream is not None:
            self.pstream.close()
            self.pstream = None

    def __disconnect(self):
        """Drops the current connection or pending stream (if any), bumping
        seqno.  The passive listener, if any, is left open."""
        if self.rpc is not None:
            self.rpc.error(EOF)
            self.rpc.close()
            self.rpc = None
            self.seqno += 1
        elif self.stream is not None:
            self.stream.close()
            self.stream = None
            self.seqno += 1

    def __connect(self):
        """Drops any existing connection and starts a new attempt: opens an
        active stream, or in passive mode opens the listener if needed."""
        self.__disconnect()

        name = self.reconnect.get_name()
        if not self.reconnect.is_passive():
            error, self.stream = ovs.stream.Stream.open(name)
            if not error:
                self.reconnect.connecting(ovs.timeval.msec())
            else:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)
        elif self.pstream is None:
            # Passive mode: open the listener only if it is not already open.
            # It starts out as None, and run() resets it to None after a
            # listen error, so this is also how listening is re-established.
            error, self.pstream = ovs.stream.PassiveStream.open(name)
            if not error:
                self.reconnect.listening(ovs.timeval.msec())
            else:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)

        self.seqno += 1

    def run(self):
        """Performs periodic work: accepts passive connections, flushes
        output, completes active connects, and acts on the reconnect FSM's
        decision (connect, disconnect or send an "echo" probe)."""
        if self.pstream is not None:
            error, stream = self.pstream.accept()
            if error == 0:
                if self.rpc or self.stream:
                    # XXX rate-limit
                    vlog.info("%s: new connection replacing active "
                              "connection" % self.reconnect.get_name())
                    self.__disconnect()
                self.reconnect.connected(ovs.timeval.msec())
                self.rpc = Connection(stream)
            elif error != errno.EAGAIN:
                self.reconnect.listen_error(ovs.timeval.msec(), error)
                self.pstream.close()
                self.pstream = None

        if self.rpc:
            backlog = self.rpc.get_backlog()
            self.rpc.run()
            if self.rpc.get_backlog() < backlog:
                # Data previously caught in a queue was successfully sent (or
                # there's an error, which we'll catch below).
                #
                # We don't count data that is successfully sent immediately as
                # activity, because there's a lot of queuing downstream from
                # us, which means that we can push a lot of data into a
                # connection that has stalled and won't ever recover.
                self.reconnect.activity(ovs.timeval.msec())

            error = self.rpc.get_status()
            if error != 0:
                self.reconnect.disconnected(ovs.timeval.msec(), error)
                self.__disconnect()
        elif self.stream is not None:
            self.stream.run()
            error = self.stream.connect()
            if error == 0:
                self.reconnect.connected(ovs.timeval.msec())
                self.rpc = Connection(self.stream)
                self.stream = None
            elif error != errno.EAGAIN:
                self.reconnect.connect_failed(ovs.timeval.msec(), error)
                self.stream.close()
                self.stream = None

        action = self.reconnect.run(ovs.timeval.msec())
        if action == ovs.reconnect.CONNECT:
            self.__connect()
        elif action == ovs.reconnect.DISCONNECT:
            self.reconnect.disconnected(ovs.timeval.msec(), 0)
            self.__disconnect()
        elif action == ovs.reconnect.PROBE:
            if self.rpc:
                # The fixed id "echo" lets recv() recognize and suppress the
                # matching reply.
                request = Message.create_request("echo", [])
                request.id = "echo"
                self.rpc.send(request)
        else:
            assert action is None

    def wait(self, poller):
        """Registers with 'poller' the events that run() is waiting for."""
        if self.rpc is not None:
            self.rpc.wait(poller)
        elif self.stream is not None:
            self.stream.run_wait(poller)
            self.stream.connect_wait(poller)
        if self.pstream is not None:
            self.pstream.wait(poller)
        self.reconnect.wait(poller, ovs.timeval.msec())

    def get_backlog(self):
        """Returns the connection's unsent output length, or 0 if none."""
        if self.rpc is not None:
            return self.rpc.get_backlog()
        else:
            return 0

    def get_name(self):
        """Returns the name this session connects to or listens on."""
        return self.reconnect.get_name()

    def send(self, msg):
        """Queues 'msg' on the current connection.

        Returns that connection's status, or errno.ENOTCONN if there is no
        connection."""
        if self.rpc is not None:
            return self.rpc.send(msg)
        else:
            return errno.ENOTCONN

    def recv(self):
        """Returns the next received Message, or None.

        Echo requests are answered and replies to our own "echo" probes are
        swallowed here; both make this return None."""
        if self.rpc is not None:
            received_bytes = self.rpc.get_received_bytes()
            error, msg = self.rpc.recv()
            if received_bytes != self.rpc.get_received_bytes():
                # Data was successfully received.
                #
                # Previously we only counted receiving a full message as
                # activity, but with large messages or a slow connection that
                # policy could time out the session mid-message.
                self.reconnect.activity(ovs.timeval.msec())

            if not error:
                if msg.type == Message.T_REQUEST and msg.method == "echo":
                    # Echo request.  Send reply.
                    self.send(Message.create_reply(msg.params, msg.id))
                elif msg.type == Message.T_REPLY and msg.id == "echo":
                    # It's a reply to our echo request.  Suppress it.
                    pass
                else:
                    return msg
        return None

    def recv_wait(self, poller):
        """Arranges for 'poller' to wake when recv() may make progress."""
        if self.rpc is not None:
            self.rpc.recv_wait(poller)

    def is_alive(self):
        """Returns True if connected or connecting, or if the reconnect FSM
        still permits connection attempts (no limit, or a positive one)."""
        if self.rpc is not None or self.stream is not None:
            return True
        else:
            max_tries = self.reconnect.get_max_tries()
            return max_tries is None or max_tries > 0

    def is_connected(self):
        """Returns True if a Connection is currently established."""
        return self.rpc is not None

    def get_seqno(self):
        """Returns a counter that changes whenever a connection attempt starts
        or a connection is dropped."""
        return self.seqno

    def force_reconnect(self):
        """Asks the reconnect FSM to drop and re-establish the connection."""
        self.reconnect.force_reconnect(ovs.timeval.msec())