testsuite: Add timeout to add_of_br() command.
[cascardo/ovs.git] / python / ovs / socket_util.py
1 # Copyright (c) 2010, 2012, 2014, 2015 Nicira, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import errno
16 import os
17 import os.path
18 import random
19 import socket
20 import sys
21
22 import six
23 from six.moves import range
24
25 import ovs.fatal_signal
26 import ovs.poller
27 import ovs.vlog
28
29 vlog = ovs.vlog.Vlog("socket_util")
30
31
32 def make_short_name(long_name):
33     if long_name is None:
34         return None
35     long_name = os.path.abspath(long_name)
36     long_dirname = os.path.dirname(long_name)
37     tmpdir = os.getenv('TMPDIR', '/tmp')
38     for x in range(0, 1000):
39         link_name = \
40             '%s/ovs-un-py-%d-%d' % (tmpdir, random.randint(0, 10000), x)
41         try:
42             os.symlink(long_dirname, link_name)
43             ovs.fatal_signal.add_file_to_unlink(link_name)
44             return os.path.join(link_name, os.path.basename(long_name))
45         except OSError as e:
46             if e.errno != errno.EEXIST:
47                 break
48     raise Exception("Failed to create temporary symlink")
49
50
51 def free_short_name(short_name):
52     if short_name is None:
53         return
54     link_name = os.path.dirname(short_name)
55     ovs.fatal_signal.unlink_file_now(link_name)
56
57
58 def make_unix_socket(style, nonblock, bind_path, connect_path, short=False):
59     """Creates a Unix domain socket in the given 'style' (either
60     socket.SOCK_DGRAM or socket.SOCK_STREAM) that is bound to 'bind_path' (if
61     'bind_path' is not None) and connected to 'connect_path' (if 'connect_path'
62     is not None).  If 'nonblock' is true, the socket is made non-blocking.
63
64     Returns (error, socket): on success 'error' is 0 and 'socket' is a new
65     socket object, on failure 'error' is a positive errno value and 'socket' is
66     None."""
67
68     try:
69         sock = socket.socket(socket.AF_UNIX, style)
70     except socket.error as e:
71         return get_exception_errno(e), None
72
73     try:
74         if nonblock:
75             set_nonblocking(sock)
76         if bind_path is not None:
77             # Delete bind_path but ignore ENOENT.
78             try:
79                 os.unlink(bind_path)
80             except OSError as e:
81                 if e.errno != errno.ENOENT:
82                     return e.errno, None
83
84             ovs.fatal_signal.add_file_to_unlink(bind_path)
85             sock.bind(bind_path)
86
87             try:
88                 if sys.hexversion >= 0x02060000:
89                     os.fchmod(sock.fileno(), 0o700)
90                 else:
91                     os.chmod("/dev/fd/%d" % sock.fileno(), 0o700)
92             except OSError as e:
93                 pass
94         if connect_path is not None:
95             try:
96                 sock.connect(connect_path)
97             except socket.error as e:
98                 if get_exception_errno(e) != errno.EINPROGRESS:
99                     raise
100         return 0, sock
101     except socket.error as e:
102         sock.close()
103         if (bind_path is not None and
104             os.path.exists(bind_path)):
105             ovs.fatal_signal.unlink_file_now(bind_path)
106         eno = ovs.socket_util.get_exception_errno(e)
107         if (eno == "AF_UNIX path too long" and
108             os.uname()[0] == "Linux"):
109             short_connect_path = None
110             short_bind_path = None
111             connect_dirfd = None
112             bind_dirfd = None
113             # Try workaround using /proc/self/fd
114             if connect_path is not None:
115                 dirname = os.path.dirname(connect_path)
116                 basename = os.path.basename(connect_path)
117                 try:
118                     connect_dirfd = os.open(dirname,
119                                             os.O_DIRECTORY | os.O_RDONLY)
120                 except OSError as err:
121                     return get_exception_errno(err), None
122                 short_connect_path = "/proc/self/fd/%d/%s" % (connect_dirfd,
123                                                               basename)
124
125             if bind_path is not None:
126                 dirname = os.path.dirname(bind_path)
127                 basename = os.path.basename(bind_path)
128                 try:
129                     bind_dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY)
130                 except OSError as err:
131                     return get_exception_errno(err), None
132                 short_bind_path = "/proc/self/fd/%d/%s" % (bind_dirfd,
133                                                            basename)
134
135             try:
136                 return make_unix_socket(style, nonblock, short_bind_path,
137                                         short_connect_path)
138             finally:
139                 if connect_dirfd is not None:
140                     os.close(connect_dirfd)
141                 if bind_dirfd is not None:
142                     os.close(bind_dirfd)
143         elif (eno == "AF_UNIX path too long"):
144             if short:
145                 return get_exception_errno(e), None
146             short_bind_path = None
147             try:
148                 short_bind_path = make_short_name(bind_path)
149                 short_connect_path = make_short_name(connect_path)
150             except:
151                 free_short_name(short_bind_path)
152                 return errno.ENAMETOOLONG, None
153             try:
154                 return make_unix_socket(style, nonblock, short_bind_path,
155                                         short_connect_path, short=True)
156             finally:
157                 free_short_name(short_bind_path)
158                 free_short_name(short_connect_path)
159         else:
160             return get_exception_errno(e), None
161
162
163 def check_connection_completion(sock):
164     p = ovs.poller.SelectPoll()
165     p.register(sock, ovs.poller.POLLOUT)
166     pfds = p.poll(0)
167     if len(pfds) == 1:
168         revents = pfds[0][1]
169         if revents & ovs.poller.POLLERR:
170             try:
171                 # The following should raise an exception.
172                 socket.send("\0", socket.MSG_DONTWAIT)
173
174                 # (Here's where we end up if it didn't.)
175                 # XXX rate-limit
176                 vlog.err("poll return POLLERR but send succeeded")
177                 return errno.EPROTO
178             except socket.error as e:
179                 return get_exception_errno(e)
180         else:
181             return 0
182     else:
183         return errno.EAGAIN
184
185
186 def is_valid_ipv4_address(address):
187     try:
188         socket.inet_pton(socket.AF_INET, address)
189     except AttributeError:
190         try:
191             socket.inet_aton(address)
192         except socket.error:
193             return False
194     except socket.error:
195         return False
196
197     return True
198
199
200 def inet_parse_active(target, default_port):
201     address = target.split(":")
202     if len(address) >= 2:
203         host_name = ":".join(address[0:-1]).lstrip('[').rstrip(']')
204         port = int(address[-1])
205     else:
206         if default_port:
207             port = default_port
208         else:
209             raise ValueError("%s: port number must be specified" % target)
210         host_name = address[0]
211     if not host_name:
212         raise ValueError("%s: bad peer name format" % target)
213     return (host_name, port)
214
215
216 def inet_open_active(style, target, default_port, dscp):
217     address = inet_parse_active(target, default_port)
218     try:
219         is_addr_inet = is_valid_ipv4_address(address[0])
220         if is_addr_inet:
221             sock = socket.socket(socket.AF_INET, style, 0)
222             family = socket.AF_INET
223         else:
224             sock = socket.socket(socket.AF_INET6, style, 0)
225             family = socket.AF_INET6
226     except socket.error as e:
227         return get_exception_errno(e), None
228
229     try:
230         set_nonblocking(sock)
231         set_dscp(sock, family, dscp)
232         try:
233             sock.connect(address)
234         except socket.error as e:
235             if get_exception_errno(e) != errno.EINPROGRESS:
236                 raise
237         return 0, sock
238     except socket.error as e:
239         sock.close()
240         return get_exception_errno(e), None
241
242
243 def get_exception_errno(e):
244     """A lot of methods on Python socket objects raise socket.error, but that
245     exception is documented as having two completely different forms of
246     arguments: either a string or a (errno, string) tuple.  We only want the
247     errno."""
248     if isinstance(e.args, tuple):
249         return e.args[0]
250     else:
251         return errno.EPROTO
252
253
254 null_fd = -1
255
256
257 def get_null_fd():
258     """Returns a readable and writable fd for /dev/null, if successful,
259     otherwise a negative errno value.  The caller must not close the returned
260     fd (because the same fd will be handed out to subsequent callers)."""
261     global null_fd
262     if null_fd < 0:
263         try:
264             null_fd = os.open("/dev/null", os.O_RDWR)
265         except OSError as e:
266             vlog.err("could not open /dev/null: %s" % os.strerror(e.errno))
267             return -e.errno
268     return null_fd
269
270
271 def write_fully(fd, buf):
272     """Returns an (error, bytes_written) tuple where 'error' is 0 on success,
273     otherwise a positive errno value, and 'bytes_written' is the number of
274     bytes that were written before the error occurred.  'error' is 0 if and
275     only if 'bytes_written' is len(buf)."""
276     bytes_written = 0
277     if len(buf) == 0:
278         return 0, 0
279     if sys.version_info[0] >= 3 and not isinstance(buf, six.binary_type):
280         buf = six.binary_type(buf, 'utf-8')
281     while True:
282         try:
283             retval = os.write(fd, buf)
284             assert retval >= 0
285             if retval == len(buf):
286                 return 0, bytes_written + len(buf)
287             elif retval == 0:
288                 vlog.warn("write returned 0")
289                 return errno.EPROTO, bytes_written
290             else:
291                 bytes_written += retval
292                 buf = buf[:retval]
293         except OSError as e:
294             return e.errno, bytes_written
295
296
297 def set_nonblocking(sock):
298     try:
299         sock.setblocking(0)
300     except socket.error as e:
301         vlog.err("could not set nonblocking mode on socket: %s"
302                  % os.strerror(get_exception_errno(e)))
303
304
305 def set_dscp(sock, family, dscp):
306     if dscp > 63:
307         raise ValueError("Invalid dscp %d" % dscp)
308
309     val = dscp << 2
310     if family == socket.AF_INET:
311         sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, val)
312     elif family == socket.AF_INET6:
313         sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_TCLASS, val)
314     else:
315         raise ValueError('Invalid family %d' % family)