python: Remove reamining direct type comparisons.
[cascardo/ovs.git] / python / ovs / socket_util.py
1 # Copyright (c) 2010, 2012, 2014, 2015 Nicira, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at:
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14
15 import errno
16 import os
17 import os.path
18 import random
19 import socket
20 import sys
21
22 from six.moves import range
23
24 import ovs.fatal_signal
25 import ovs.poller
26 import ovs.vlog
27
28 vlog = ovs.vlog.Vlog("socket_util")
29
30
31 def make_short_name(long_name):
32     if long_name is None:
33         return None
34     long_name = os.path.abspath(long_name)
35     long_dirname = os.path.dirname(long_name)
36     tmpdir = os.getenv('TMPDIR', '/tmp')
37     for x in range(0, 1000):
38         link_name = \
39             '%s/ovs-un-py-%d-%d' % (tmpdir, random.randint(0, 10000), x)
40         try:
41             os.symlink(long_dirname, link_name)
42             ovs.fatal_signal.add_file_to_unlink(link_name)
43             return os.path.join(link_name, os.path.basename(long_name))
44         except OSError as e:
45             if e.errno != errno.EEXIST:
46                 break
47     raise Exception("Failed to create temporary symlink")
48
49
50 def free_short_name(short_name):
51     if short_name is None:
52         return
53     link_name = os.path.dirname(short_name)
54     ovs.fatal_signal.unlink_file_now(link_name)
55
56
57 def make_unix_socket(style, nonblock, bind_path, connect_path, short=False):
58     """Creates a Unix domain socket in the given 'style' (either
59     socket.SOCK_DGRAM or socket.SOCK_STREAM) that is bound to 'bind_path' (if
60     'bind_path' is not None) and connected to 'connect_path' (if 'connect_path'
61     is not None).  If 'nonblock' is true, the socket is made non-blocking.
62
63     Returns (error, socket): on success 'error' is 0 and 'socket' is a new
64     socket object, on failure 'error' is a positive errno value and 'socket' is
65     None."""
66
67     try:
68         sock = socket.socket(socket.AF_UNIX, style)
69     except socket.error as e:
70         return get_exception_errno(e), None
71
72     try:
73         if nonblock:
74             set_nonblocking(sock)
75         if bind_path is not None:
76             # Delete bind_path but ignore ENOENT.
77             try:
78                 os.unlink(bind_path)
79             except OSError as e:
80                 if e.errno != errno.ENOENT:
81                     return e.errno, None
82
83             ovs.fatal_signal.add_file_to_unlink(bind_path)
84             sock.bind(bind_path)
85
86             try:
87                 if sys.hexversion >= 0x02060000:
88                     os.fchmod(sock.fileno(), 0o700)
89                 else:
90                     os.chmod("/dev/fd/%d" % sock.fileno(), 0o700)
91             except OSError as e:
92                 pass
93         if connect_path is not None:
94             try:
95                 sock.connect(connect_path)
96             except socket.error as e:
97                 if get_exception_errno(e) != errno.EINPROGRESS:
98                     raise
99         return 0, sock
100     except socket.error as e:
101         sock.close()
102         if (bind_path is not None and
103             os.path.exists(bind_path)):
104             ovs.fatal_signal.unlink_file_now(bind_path)
105         eno = ovs.socket_util.get_exception_errno(e)
106         if (eno == "AF_UNIX path too long" and
107             os.uname()[0] == "Linux"):
108             short_connect_path = None
109             short_bind_path = None
110             connect_dirfd = None
111             bind_dirfd = None
112             # Try workaround using /proc/self/fd
113             if connect_path is not None:
114                 dirname = os.path.dirname(connect_path)
115                 basename = os.path.basename(connect_path)
116                 try:
117                     connect_dirfd = os.open(dirname,
118                                             os.O_DIRECTORY | os.O_RDONLY)
119                 except OSError as err:
120                     return get_exception_errno(err), None
121                 short_connect_path = "/proc/self/fd/%d/%s" % (connect_dirfd,
122                                                               basename)
123
124             if bind_path is not None:
125                 dirname = os.path.dirname(bind_path)
126                 basename = os.path.basename(bind_path)
127                 try:
128                     bind_dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY)
129                 except OSError as err:
130                     return get_exception_errno(err), None
131                 short_bind_path = "/proc/self/fd/%d/%s" % (bind_dirfd,
132                                                            basename)
133
134             try:
135                 return make_unix_socket(style, nonblock, short_bind_path,
136                                         short_connect_path)
137             finally:
138                 if connect_dirfd is not None:
139                     os.close(connect_dirfd)
140                 if bind_dirfd is not None:
141                     os.close(bind_dirfd)
142         elif (eno == "AF_UNIX path too long"):
143             if short:
144                 return get_exception_errno(e), None
145             short_bind_path = None
146             try:
147                 short_bind_path = make_short_name(bind_path)
148                 short_connect_path = make_short_name(connect_path)
149             except:
150                 free_short_name(short_bind_path)
151                 return errno.ENAMETOOLONG, None
152             try:
153                 return make_unix_socket(style, nonblock, short_bind_path,
154                                         short_connect_path, short=True)
155             finally:
156                 free_short_name(short_bind_path)
157                 free_short_name(short_connect_path)
158         else:
159             return get_exception_errno(e), None
160
161
162 def check_connection_completion(sock):
163     p = ovs.poller.SelectPoll()
164     p.register(sock, ovs.poller.POLLOUT)
165     pfds = p.poll(0)
166     if len(pfds) == 1:
167         revents = pfds[0][1]
168         if revents & ovs.poller.POLLERR:
169             try:
170                 # The following should raise an exception.
171                 socket.send("\0", socket.MSG_DONTWAIT)
172
173                 # (Here's where we end up if it didn't.)
174                 # XXX rate-limit
175                 vlog.err("poll return POLLERR but send succeeded")
176                 return errno.EPROTO
177             except socket.error as e:
178                 return get_exception_errno(e)
179         else:
180             return 0
181     else:
182         return errno.EAGAIN
183
184
185 def is_valid_ipv4_address(address):
186     try:
187         socket.inet_pton(socket.AF_INET, address)
188     except AttributeError:
189         try:
190             socket.inet_aton(address)
191         except socket.error:
192             return False
193     except socket.error:
194         return False
195
196     return True
197
198
199 def inet_parse_active(target, default_port):
200     address = target.split(":")
201     if len(address) >= 2:
202         host_name = ":".join(address[0:-1]).lstrip('[').rstrip(']')
203         port = int(address[-1])
204     else:
205         if default_port:
206             port = default_port
207         else:
208             raise ValueError("%s: port number must be specified" % target)
209         host_name = address[0]
210     if not host_name:
211         raise ValueError("%s: bad peer name format" % target)
212     return (host_name, port)
213
214
215 def inet_open_active(style, target, default_port, dscp):
216     address = inet_parse_active(target, default_port)
217     try:
218         is_addr_inet = is_valid_ipv4_address(address[0])
219         if is_addr_inet:
220             sock = socket.socket(socket.AF_INET, style, 0)
221             family = socket.AF_INET
222         else:
223             sock = socket.socket(socket.AF_INET6, style, 0)
224             family = socket.AF_INET6
225     except socket.error as e:
226         return get_exception_errno(e), None
227
228     try:
229         set_nonblocking(sock)
230         set_dscp(sock, family, dscp)
231         try:
232             sock.connect(address)
233         except socket.error as e:
234             if get_exception_errno(e) != errno.EINPROGRESS:
235                 raise
236         return 0, sock
237     except socket.error as e:
238         sock.close()
239         return get_exception_errno(e), None
240
241
242 def get_exception_errno(e):
243     """A lot of methods on Python socket objects raise socket.error, but that
244     exception is documented as having two completely different forms of
245     arguments: either a string or a (errno, string) tuple.  We only want the
246     errno."""
247     if isinstance(e.args, tuple):
248         return e.args[0]
249     else:
250         return errno.EPROTO
251
252
253 null_fd = -1
254
255
256 def get_null_fd():
257     """Returns a readable and writable fd for /dev/null, if successful,
258     otherwise a negative errno value.  The caller must not close the returned
259     fd (because the same fd will be handed out to subsequent callers)."""
260     global null_fd
261     if null_fd < 0:
262         try:
263             null_fd = os.open("/dev/null", os.O_RDWR)
264         except OSError as e:
265             vlog.err("could not open /dev/null: %s" % os.strerror(e.errno))
266             return -e.errno
267     return null_fd
268
269
270 def write_fully(fd, buf):
271     """Returns an (error, bytes_written) tuple where 'error' is 0 on success,
272     otherwise a positive errno value, and 'bytes_written' is the number of
273     bytes that were written before the error occurred.  'error' is 0 if and
274     only if 'bytes_written' is len(buf)."""
275     bytes_written = 0
276     if len(buf) == 0:
277         return 0, 0
278     while True:
279         try:
280             retval = os.write(fd, buf)
281             assert retval >= 0
282             if retval == len(buf):
283                 return 0, bytes_written + len(buf)
284             elif retval == 0:
285                 vlog.warn("write returned 0")
286                 return errno.EPROTO, bytes_written
287             else:
288                 bytes_written += retval
289                 buf = buf[:retval]
290         except OSError as e:
291             return e.errno, bytes_written
292
293
294 def set_nonblocking(sock):
295     try:
296         sock.setblocking(0)
297     except socket.error as e:
298         vlog.err("could not set nonblocking mode on socket: %s"
299                  % os.strerror(get_exception_errno(e)))
300
301
302 def set_dscp(sock, family, dscp):
303     if dscp > 63:
304         raise ValueError("Invalid dscp %d" % dscp)
305
306     val = dscp << 2
307     if family == socket.AF_INET:
308         sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, val)
309     elif family == socket.AF_INET6:
310         sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_TCLASS, val)
311     else:
312         raise ValueError('Invalid family %d' % family)