odp-util: Format and scan multiple MPLS labels.
[cascardo/ovs.git] / python / ovs / socket_util.py
index 7bfefc4..9f46a55 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (c) 2010, 2012 Nicira, Inc.
+# Copyright (c) 2010, 2012, 2014, 2015 Nicira, Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 
 import errno
 import os
-import select
+import os.path
+import random
 import socket
 import sys
 
+import six
+from six.moves import range
+
 import ovs.fatal_signal
 import ovs.poller
 import ovs.vlog
@@ -25,7 +29,33 @@ import ovs.vlog
 vlog = ovs.vlog.Vlog("socket_util")
 
 
-def make_unix_socket(style, nonblock, bind_path, connect_path):
+def make_short_name(long_name):
+    if long_name is None:
+        return None
+    long_name = os.path.abspath(long_name)
+    long_dirname = os.path.dirname(long_name)
+    tmpdir = os.getenv('TMPDIR', '/tmp')
+    for x in range(0, 1000):
+        link_name = \
+            '%s/ovs-un-py-%d-%d' % (tmpdir, random.randint(0, 10000), x)
+        try:
+            os.symlink(long_dirname, link_name)
+            ovs.fatal_signal.add_file_to_unlink(link_name)
+            return os.path.join(link_name, os.path.basename(long_name))
+        except OSError as e:
+            if e.errno != errno.EEXIST:
+                break
+    raise Exception("Failed to create temporary symlink")
+
+
+def free_short_name(short_name):
+    if short_name is None:
+        return
+    link_name = os.path.dirname(short_name)
+    ovs.fatal_signal.unlink_file_now(link_name)
+
+
+def make_unix_socket(style, nonblock, bind_path, connect_path, short=False):
     """Creates a Unix domain socket in the given 'style' (either
     socket.SOCK_DGRAM or socket.SOCK_STREAM) that is bound to 'bind_path' (if
     'bind_path' is not None) and connected to 'connect_path' (if 'connect_path'
@@ -37,7 +67,7 @@ def make_unix_socket(style, nonblock, bind_path, connect_path):
 
     try:
         sock = socket.socket(socket.AF_UNIX, style)
-    except socket.error, e:
+    except socket.error as e:
         return get_exception_errno(e), None
 
     try:
@@ -47,7 +77,7 @@ def make_unix_socket(style, nonblock, bind_path, connect_path):
             # Delete bind_path but ignore ENOENT.
             try:
                 os.unlink(bind_path)
-            except OSError, e:
+            except OSError as e:
                 if e.errno != errno.ENOENT:
                     return e.errno, None
 
@@ -56,19 +86,19 @@ def make_unix_socket(style, nonblock, bind_path, connect_path):
 
             try:
                 if sys.hexversion >= 0x02060000:
-                    os.fchmod(sock.fileno(), 0700)
+                    os.fchmod(sock.fileno(), 0o700)
                 else:
-                    os.chmod("/dev/fd/%d" % sock.fileno(), 0700)
-            except OSError, e:
+                    os.chmod("/dev/fd/%d" % sock.fileno(), 0o700)
+            except OSError as e:
                 pass
         if connect_path is not None:
             try:
                 sock.connect(connect_path)
-            except socket.error, e:
+            except socket.error as e:
                 if get_exception_errno(e) != errno.EINPROGRESS:
                     raise
         return 0, sock
-    except socket.error, e:
+    except socket.error as e:
         sock.close()
         if (bind_path is not None and
             os.path.exists(bind_path)):
@@ -85,27 +115,47 @@ def make_unix_socket(style, nonblock, bind_path, connect_path):
                 dirname = os.path.dirname(connect_path)
                 basename = os.path.basename(connect_path)
                 try:
-                    connect_dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY)
-                except OSError, err:
-                    return get_exception_errno(e), None
-                short_connect_path = "/proc/self/fd/%d/%s" % (connect_dirfd, basename)
+                    connect_dirfd = os.open(dirname,
+                                            os.O_DIRECTORY | os.O_RDONLY)
+                except OSError as err:
+                    return get_exception_errno(err), None
+                short_connect_path = "/proc/self/fd/%d/%s" % (connect_dirfd,
+                                                              basename)
 
             if bind_path is not None:
                 dirname = os.path.dirname(bind_path)
                 basename = os.path.basename(bind_path)
                 try:
                     bind_dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY)
-                except OSError, err:
-                    return get_exception_errno(e), None
-                short_bind_path = "/proc/self/fd/%d/%s" % (bind_dirfd, basename)
+                except OSError as err:
+                    return get_exception_errno(err), None
+                short_bind_path = "/proc/self/fd/%d/%s" % (bind_dirfd,
+                                                           basename)
 
             try:
-                return make_unix_socket(style, nonblock, short_bind_path, short_connect_path)
+                return make_unix_socket(style, nonblock, short_bind_path,
+                                        short_connect_path)
             finally:
                 if connect_dirfd is not None:
                     os.close(connect_dirfd)
                 if bind_dirfd is not None:
                     os.close(bind_dirfd)
+        elif (eno == "AF_UNIX path too long"):
+            if short:
+                return get_exception_errno(e), None
+            short_bind_path = None
+            try:
+                short_bind_path = make_short_name(bind_path)
+                short_connect_path = make_short_name(connect_path)
+            except:
+                free_short_name(short_bind_path)
+                return errno.ENAMETOOLONG, None
+            try:
+                return make_unix_socket(style, nonblock, short_bind_path,
+                                        short_connect_path, short=True)
+            finally:
+                free_short_name(short_bind_path)
+                free_short_name(short_connect_path)
         else:
             return get_exception_errno(e), None
 
@@ -125,7 +175,7 @@ def check_connection_completion(sock):
                 # XXX rate-limit
                 vlog.err("poll return POLLERR but send succeeded")
                 return errno.EPROTO
-            except socket.error, e:
+            except socket.error as e:
                 return get_exception_errno(e)
         else:
             return 0
@@ -133,37 +183,59 @@ def check_connection_completion(sock):
         return errno.EAGAIN
 
 
+def is_valid_ipv4_address(address):
+    try:
+        socket.inet_pton(socket.AF_INET, address)
+    except AttributeError:
+        try:
+            socket.inet_aton(address)
+        except socket.error:
+            return False
+    except socket.error:
+        return False
+
+    return True
+
+
 def inet_parse_active(target, default_port):
     address = target.split(":")
-    host_name = address[0]
-    if not host_name:
-        raise ValueError("%s: bad peer name format" % target)
     if len(address) >= 2:
-        port = int(address[1])
-    elif default_port:
-        port = default_port
+        host_name = ":".join(address[0:-1]).lstrip('[').rstrip(']')
+        port = int(address[-1])
     else:
-        raise ValueError("%s: port number must be specified" % target)
+        if default_port:
+            port = default_port
+        else:
+            raise ValueError("%s: port number must be specified" % target)
+        host_name = address[0]
+    if not host_name:
+        raise ValueError("%s: bad peer name format" % target)
     return (host_name, port)
 
 
 def inet_open_active(style, target, default_port, dscp):
     address = inet_parse_active(target, default_port)
     try:
-        sock = socket.socket(socket.AF_INET, style, 0)
-    except socket.error, e:
+        is_addr_inet = is_valid_ipv4_address(address[0])
+        if is_addr_inet:
+            sock = socket.socket(socket.AF_INET, style, 0)
+            family = socket.AF_INET
+        else:
+            sock = socket.socket(socket.AF_INET6, style, 0)
+            family = socket.AF_INET6
+    except socket.error as e:
         return get_exception_errno(e), None
 
     try:
         set_nonblocking(sock)
-        set_dscp(sock, dscp)
+        set_dscp(sock, family, dscp)
         try:
             sock.connect(address)
-        except socket.error, e:
+        except socket.error as e:
             if get_exception_errno(e) != errno.EINPROGRESS:
                 raise
         return 0, sock
-    except socket.error, e:
+    except socket.error as e:
         sock.close()
         return get_exception_errno(e), None
 
@@ -173,7 +245,7 @@ def get_exception_errno(e):
     exception is documented as having two completely different forms of
     arguments: either a string or a (errno, string) tuple.  We only want the
     errno."""
-    if type(e.args) == tuple:
+    if isinstance(e.args, tuple):
         return e.args[0]
     else:
         return errno.EPROTO
@@ -190,7 +262,7 @@ def get_null_fd():
     if null_fd < 0:
         try:
             null_fd = os.open("/dev/null", os.O_RDWR)
-        except OSError, e:
+        except OSError as e:
             vlog.err("could not open /dev/null: %s" % os.strerror(e.errno))
             return -e.errno
     return null_fd
@@ -204,6 +276,8 @@ def write_fully(fd, buf):
     bytes_written = 0
     if len(buf) == 0:
         return 0, 0
+    if sys.version_info[0] >= 3 and not isinstance(buf, six.binary_type):
+        buf = six.binary_type(buf, 'utf-8')
     while True:
         try:
             retval = os.write(fd, buf)
@@ -216,20 +290,26 @@ def write_fully(fd, buf):
             else:
                 bytes_written += retval
                 buf = buf[:retval]
-        except OSError, e:
+        except OSError as e:
             return e.errno, bytes_written
 
 
 def set_nonblocking(sock):
     try:
         sock.setblocking(0)
-    except socket.error, e:
+    except socket.error as e:
         vlog.err("could not set nonblocking mode on socket: %s"
                  % os.strerror(get_exception_errno(e)))
 
 
-def set_dscp(sock, dscp):
+def set_dscp(sock, family, dscp):
     if dscp > 63:
         raise ValueError("Invalid dscp %d" % dscp)
+
     val = dscp << 2
-    sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, val)
+    if family == socket.AF_INET:
+        sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, val)
+    elif family == socket.AF_INET6:
+        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_TCLASS, val)
+    else:
+        raise ValueError('Invalid family %d' % family)