Additional data store refactoring
authorSimo Sorce <simo@redhat.com>
Fri, 26 Sep 2014 21:41:04 +0000 (17:41 -0400)
committerPatrick Uiterwijk <puiterwijk@redhat.com>
Mon, 6 Oct 2014 18:55:08 +0000 (20:55 +0200)
Use sqlalchemy to access Sql databases, which  are the only implemented
database backends for now.
If no database type is specified we assume a sqlite3 database file path
is configured (this is backwards compatible with current configuration
statements)

Signed-off-by: Simo Sorce <simo@redhat.com>
Reviewed-by: Patrick Uiterwijk <puiterwijk@redhat.com>
ipsilon/util/data.py

index e6bca10..78fa5ab 100755 (executable)
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
-import sqlite3
 import cherrypy
 from ipsilon.util.log import Log
+from sqlalchemy import create_engine
+from sqlalchemy import MetaData, Table, Column, Text
+from sqlalchemy.sql import select
 import uuid
 
 
@@ -27,81 +29,105 @@ OPTIONS_COLUMNS = ['name', 'option', 'value']
 UNIQUE_DATA_COLUMNS = ['uuid', 'name', 'value']
 
 
+class SqlStore(Log):
+
+    def __init__(self, name):
+        if name not in cherrypy.config:
+            raise NameError('Unknown database %s' % name)
+        engine_name = cherrypy.config[name]
+        if '://' not in engine_name:
+            engine_name = 'sqlite:///' + engine_name
+        self._dbengine = create_engine(engine_name)
+
+    def engine(self):
+        return self._dbengine
+
+    def connection(self):
+        return self._dbengine.connect()
+
+
+def SqlAutotable(f):
+    def at(self, *args, **kwargs):
+        if self.autotable:
+            self.create()
+        return f(self, *args, **kwargs)
+    return at
+
+
+class SqlQuery(Log):
+
+    def __init__(self, db_obj, table, columns, autotable=True, trans=True):
+        self._db = db_obj
+        self.autotable = autotable
+        self._con = self._db.connection()
+        self._trans = self._con.begin() if trans else None
+        self._table = self._get_table(table, columns)
+
+    def _get_table(self, name, columns):
+        table = Table(name, MetaData(self._db.engine()))
+        for c in columns:
+            table.append_column(Column(c, Text()))
+        return table
+
+    def _where(self, kvfilter):
+        where = None
+        if kvfilter is not None:
+            for k in kvfilter:
+                w = self._table.columns[k] == kvfilter[k]
+                if where is None:
+                    where = w
+                else:
+                    where = where & w
+        return where
+
+    def _columns(self, columns=None):
+        cols = None
+        if columns is not None:
+            cols = []
+            for c in columns:
+                cols.append(self._table.columns[c])
+        else:
+            cols = self._table.columns
+        return cols
+
+    def rollback(self):
+        self._trans.rollback()
+
+    def commit(self):
+        self._trans.commit()
+
+    def create(self):
+        self._table.create(checkfirst=True)
+
+    def drop(self):
+        self._table.drop(checkfirst=True)
+
+    @SqlAutotable
+    def select(self, kvfilter=None, columns=None):
+        return self._con.execute(select(self._columns(columns),
+                                        self._where(kvfilter)))
+
+    @SqlAutotable
+    def insert(self, values):
+        self._con.execute(self._table.insert(values))
+
+    @SqlAutotable
+    def update(self, values, kvfilter):
+        self._con.execute(self._table.update(self._where(kvfilter), values))
+
+    @SqlAutotable
+    def delete(self, kvfilter):
+        self._con.execute(self._table.delete(self._where(kvfilter)))
+
+
 class Store(Log):
 
     def __init__(self, config_name):
-        if config_name not in cherrypy.config:
-            raise NameError('Unknown database type %s' % config_name)
-        self._dbname = cherrypy.config[config_name]
-
-    def _build_where(self, kvfilter, kvout):
-        where = ""
-        sep = "WHERE"
-        for k in kvfilter:
-            mk = "where_%s" % k
-            kvout[mk] = kvfilter[k]
-            where += "%s %s=:%s" % (sep, k, mk)
-            sep = " AND"
-        return where
+        self._db = SqlStore(config_name)
+        self._query = SqlQuery
 
-    def _build_select(self, table, kvfilter=None, kvout=None, columns=None):
-        SELECT = "SELECT %(cols)s FROM %(table)s %(where)s"
-        cols = "*"
-        if columns:
-            cols = ",".join(columns)
-        where = ""
-        if kvfilter is not None:
-            where = self._build_where(kvfilter, kvout)
-        return SELECT % {'table': table, 'cols': cols, 'where': where}
-
-    def _select(self, cursor, table, kvfilter=None, columns=None):
-        kv = dict()
-        select = self._build_select(table, kvfilter, kv, columns)
-        cursor.execute(select, kv)
-        return cursor.fetchall()
-
-    def _create(self, cursor, table, columns):
-        CREATE = "CREATE TABLE IF NOT EXISTS %(table)s(%(cols)s)"
-        cols = ",".join(columns)
-        create = CREATE % {'table': table, 'cols': cols}
-        cursor.execute(create)
-
-    def _drop(self, cursor, table):
-        cursor.execute("DROP TABLE IF EXISTS " + table)
-
-    def _update(self, cursor, table, values, kvfilter):
-        UPDATE = "UPDATE %(table)s SET %(setval)s %(where)s"
-        kv = dict()
-
-        setval = ""
-        sep = ""
-        for k in values:
-            mk = "setval_%s" % k
-            kv[mk] = values[k]
-            setval += "%s%s=:%s" % (sep, k, mk)
-            sep = " , "
-
-        where = self._build_where(kvfilter, kv)
-
-        update = UPDATE % {'table': table, 'setval': setval, 'where': where}
-        cursor.execute(update, kv)
-
-    def _insert(self, cursor, table, values):
-        INSERT = "INSERT INTO %(table)s VALUES(%(values)s)"
-        vals = ""
-        sep = ""
-        for _ in values:
-            vals += "%s?" % sep
-            sep = ","
-        insert = INSERT % {'table': table, 'values': vals}
-        cursor.execute(insert, values)
-
-    def _delete(self, cursor, table, kvfilter):
-        DELETE = "DELETE FROM %(table)s %(where)s"
-        kv = dict()
-        where = self._build_where(kvfilter, kv)
-        delete = DELETE % {'table': table, 'where': where}
-        cursor.execute(delete, kv)
+    def new_query(self, table, columns=None, autotable=True, autocommit=True):
+        return self._query(self._db, table, columns, autotable, autocommit)
 
     def _row_to_dict_tree(self, data, row):
         name = row[0]
@@ -127,109 +153,83 @@ class Store(Log):
             self._row_to_dict_tree(data, r)
         return data
 
-    def _load_data(self, table, columns, kvfilter=None):
-        con = None
+    def load_data(self, table, columns, kvfilter=None):
         rows = []
         try:
-            con = sqlite3.connect(self._dbname)
-            cur = con.cursor()
-            self._create(cur, table, columns)
-            rows = self._select(cur, table, kvfilter)
-            con.commit()
-        except sqlite3.Error, e:
-            if con:
-                con.rollback()
+            q = self._query(self._db, table, columns, trans=False)
+            rows = q.select(kvfilter)
+        except Exception, e:  # pylint: disable=broad-except
             self.error("Failed to load data for table %s: [%s]" % (table, e))
-        finally:
-            if con:
-                con.close()
-
         return self._rows_to_dict_tree(rows)
 
     def load_config(self):
         table = 'config'
         columns = ['name', 'value']
-        return self._load_data(table, columns)
+        return self.load_data(table, columns)
 
     def load_options(self, table, name=None):
         kvfilter = dict()
         if name:
             kvfilter['name'] = name
-        options = self._load_data(table, OPTIONS_COLUMNS, kvfilter)
+        options = self.load_data(table, OPTIONS_COLUMNS, kvfilter)
         if name and name in options:
             return options[name]
         return options
 
     def save_options(self, table, name, options):
         curvals = dict()
-        con = None
+        q = None
         try:
-            con = sqlite3.connect(self._dbname)
-            cur = con.cursor()
-            self._create(cur, table, OPTIONS_COLUMNS)
-            rows = self._select(cur, table, {'name': name},
-                                ['option', 'value'])
+            q = self._query(self._db, table, OPTIONS_COLUMNS)
+            rows = q.select({'name': name}, ['option', 'value'])
             for row in rows:
                 curvals[row[0]] = row[1]
 
             for opt in options:
                 if opt in curvals:
-                    self._update(cur, table,
-                                 {'value': options[opt]},
-                                 {'name': name, 'option': opt})
+                    q.update({'value': options[opt]},
+                             {'name': name, 'option': opt})
                 else:
-                    self._insert(cur, table, (name, opt, options[opt]))
+                    q.insert((name, opt, options[opt]))
 
-            con.commit()
-        except sqlite3.Error, e:
-            if con:
-                con.rollback()
-            self.error("Failed to store config: [%s]" % e)
+            q.commit()
+        except Exception, e:  # pylint: disable=broad-except
+            if q:
+                q.rollback()
+            self.error("Failed to save options: [%s]" % e)
             raise
-        finally:
-            if con:
-                con.close()
 
     def delete_options(self, table, name, options=None):
         kvfilter = {'name': name}
+        q = None
         try:
-            con = sqlite3.connect(self._dbname)
-            cur = con.cursor()
-            self._create(cur, table, OPTIONS_COLUMNS)
+            q = self._query(self._db, table, OPTIONS_COLUMNS)
             if options is None:
-                self._delete(cur, table, kvfilter)
+                q.delete(kvfilter)
             else:
                 for opt in options:
                     kvfilter['option'] = opt
-                    self._delete(cur, table, kvfilter)
-            con.commit()
-        except sqlite3.Error, e:
-            if con:
-                con.rollback()
+                    q.delete(kvfilter)
+            q.commit()
+        except Exception, e:  # pylint: disable=broad-except
+            if q:
+                q.rollback()
             self.error("Failed to delete from %s: [%s]" % (table, e))
             raise
-        finally:
-            if con:
-                con.close()
 
     def new_unique_data(self, table, data):
-        con = None
+        newid = str(uuid.uuid4())
+        q = None
         try:
-            con = sqlite3.connect(self._dbname)
-            cur = con.cursor()
-            self._create(cur, table, UNIQUE_DATA_COLUMNS)
-            newid = str(uuid.uuid4())
+            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
             for name in data:
-                self._insert(cur, table, (newid, name, data[name]))
-            con.commit()
-        except sqlite3.Error, e:
-            if con:
-                con.rollback()
-            cherrypy.log.error("Failed to store %s data: [%s]" % (table, e))
+                q.insert((newid, name, data[name]))
+            q.commit()
+        except Exception, e:  # pylint: disable=broad-except
+            if q:
+                q.rollback()
+            self.error("Failed to store %s data: [%s]" % (table, e))
             raise
-        finally:
-            if con:
-                con.close()
         return newid
 
     def get_unique_data(self, table, uuidval=None, name=None, value=None):
@@ -240,68 +240,51 @@ class Store(Log):
             kvfilter['name'] = name
         if value:
             kvfilter['value'] = value
-        return self._load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
+        return self.load_data(table, UNIQUE_DATA_COLUMNS, kvfilter)
 
     def save_unique_data(self, table, data):
-        curvals = dict()
-        con = None
+        q = None
         try:
-            con = sqlite3.connect(self._dbname)
-            cur = con.cursor()
-            self._create(cur, table, UNIQUE_DATA_COLUMNS)
+            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
             for uid in data:
                 curvals = dict()
-                rows = self._select(cur, table, {'uuid': uid},
-                                    ['name', 'value'])
+                rows = q.select({'uuid': uid}, ['name', 'value'])
                 for r in rows:
                     curvals[r[0]] = r[1]
 
                 datum = data[uid]
                 for name in datum:
                     if name in curvals:
-                        self._update(cur, table,
-                                     {'value': datum[name]},
-                                     {'uuid': uid, 'name': name})
+                        q.update({'value': datum[name]},
+                                 {'uuid': uid, 'name': name})
                     else:
-                        self._insert(cur, table, (uid, name, datum[name]))
+                        q.insert((uid, name, datum[name]))
 
-            con.commit()
-        except sqlite3.Error, e:
-            if con:
-                con.rollback()
+            q.commit()
+        except Exception, e:  # pylint: disable=broad-except
+            if q:
+                q.rollback()
             self.error("Failed to store data in %s: [%s]" % (table, e))
             raise
-        finally:
-            if con:
-                con.close()
 
     def del_unique_data(self, table, uuidval):
         kvfilter = {'uuid': uuidval}
-        con = None
         try:
-            con = sqlite3.connect(self._dbname)
-            cur = con.cursor()
-            self._delete(cur, table, kvfilter)
-        except sqlite3.Error, e:
+            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS, trans=False)
+            q.delete(kvfilter)
+        except Exception, e:  # pylint: disable=broad-except
             self.error("Failed to delete data from %s: [%s]" % (table, e))
-        finally:
-            if con:
-                con.close()
 
     def reset_data(self, table):
         try:
-            con = sqlite3.connect(self._dbname)
-            cur = con.cursor()
-            self._drop(cur, table)
-            self._create(cur, table, UNIQUE_DATA_COLUMNS)
-            con.commit()
-        except sqlite3.Error, e:
-            if con:
-                con.rollback()
+            q = self._query(self._db, table, UNIQUE_DATA_COLUMNS)
+            q.drop()
+            q.create()
+            q.commit()
+        except Exception, e:  # pylint: disable=broad-except
+            if q:
+                q.rollback()
             self.error("Failed to erase all data from %s: [%s]" % (table, e))
-        finally:
-            if con:
-                con.close()
 
 
 class AdminStore(Store):