Merge remote-tracking branch 'upstream/master' into gerrit-submission
[osm/RO.git] / osm_ro / db_base.py
index 2a6cd8c..7b48f43 100644 (file)
@@ -37,14 +37,7 @@ import logging
 import datetime
 from jsonschema import validate as js_v, exceptions as js_e
 
-HTTP_Bad_Request = 400
-HTTP_Unauthorized = 401 
-HTTP_Not_Found = 404 
-HTTP_Method_Not_Allowed = 405 
-HTTP_Request_Timeout = 408
-HTTP_Conflict = 409
-HTTP_Service_Unavailable = 503 
-HTTP_Internal_Server_Error = 500 
+from .http_tools import errors as httperrors
 
 def _check_valid_uuid(uuid):
     id_schema = {"type" : "string", "pattern": "^[a-fA-F0-9]{8}(-[a-fA-F0-9]{4}){3}-[a-fA-F0-9]{12}$"}
@@ -68,7 +61,7 @@ def _convert_datetime2str(var):
         for k,v in var.items():
             if type(v) is datetime.datetime:
                 var[k]= v.strftime('%Y-%m-%dT%H:%M:%S')
-            elif type(v) is dict or type(v) is list or type(v) is tuple: 
+            elif type(v) is dict or type(v) is list or type(v) is tuple:
                 _convert_datetime2str(v)
         if len(var) == 0: return True
     elif type(var) is list or type(var) is tuple:
@@ -76,7 +69,7 @@ def _convert_datetime2str(var):
             _convert_datetime2str(v)
 
 def _convert_bandwidth(data, reverse=False, logger=None):
-    '''Check the field bandwidth recursivelly and when found, it removes units and convert to number 
+    '''Check the field bandwidth recursivelly and when found, it removes units and convert to number
     It assumes that bandwidth is well formed
     Attributes:
         'data': dictionary bottle.FormsDict variable to be checked. None or empty is consideted valid
@@ -111,7 +104,7 @@ def _convert_bandwidth(data, reverse=False, logger=None):
                 _convert_bandwidth(k, reverse, logger)
 
 def _convert_str2boolean(data, items):
-    '''Check recursively the content of data, and if there is an key contained in items, convert value from string to boolean 
+    '''Check recursively the content of data, and if there is an key contained in items, convert value from string to boolean
     Done recursively
     Attributes:
         'data': dictionary variable to be checked. None or empty is considered valid
@@ -135,16 +128,15 @@ def _convert_str2boolean(data, items):
             if type(k) is dict or type(k) is tuple or type(k) is list:
                 _convert_str2boolean(k, items)
 
-class db_base_Exception(Exception):
+class db_base_Exception(httperrors.HttpMappedError):
     '''Common Exception for all database exceptions'''
-    
-    def __init__(self, message, http_code=HTTP_Bad_Request):
-        Exception.__init__(self, message)
-        self.http_code = http_code
+
+    def __init__(self, message, http_code=httperrors.Bad_Request):
+        super(db_base_Exception, self).__init__(message, http_code)
 
 class db_base():
     tables_with_created_field=()
-    
+
     def __init__(self, host=None, user=None, passwd=None, database=None, log_name='db', log_level=None):
         self.host = host
         self.user = user
@@ -155,9 +147,9 @@ class db_base():
         self.logger = logging.getLogger(log_name)
         if self.log_level:
             self.logger.setLevel( getattr(logging, log_level) )
-        
+
     def connect(self, host=None, user=None, passwd=None, database=None):
-        '''Connect to specific data base. 
+        '''Connect to specific data base.
         The first time a valid host, user, passwd and database must be provided,
         Following calls can skip this parameters
         '''
@@ -172,8 +164,16 @@ class db_base():
         except mdb.Error as e:
             raise db_base_Exception("Cannot connect to DataBase '{}' at '{}@{}' Error {}: {}".format(
                                     self.database, self.user, self.host, e.args[0], e.args[1]),
-                                    http_code = HTTP_Unauthorized )
-        
+                                    http_code = httperrors.Unauthorized )
+
+    def escape(self, value):
+        return self.con.escape(value)
+
+
+    def escape_string(self, value):
+        return self.con.escape_string(value)
+
+
     def get_db_version(self):
         ''' Obtain the database schema version.
         Return: (negative, text) if error or version 0.0 where schema_version table is missing
@@ -212,10 +212,10 @@ class db_base():
             if e[0][-5:] == "'con'":
                 self.logger.warn("while disconnecting from DB: Error %d: %s",e.args[0], e.args[1])
                 return
-            else: 
+            else:
                 raise
 
-    def _format_error(self, e, tries=1, command=None, extra=None): 
+    def _format_error(self, e, tries=1, command=None, extra=None, table=None):
         '''Creates a text error base on the produced exception
             Params:
                 e: mdb exception
@@ -227,7 +227,8 @@ class db_base():
                 HTTP error in negative, formatted error text
         '''
         if isinstance(e,AttributeError ):
-            raise db_base_Exception("DB Exception " + str(e), HTTP_Internal_Server_Error)
+            self.logger.debug(str(e), exc_info=True)
+            raise db_base_Exception("DB Exception " + str(e), httperrors.Internal_Server_Error)
         if e.args[0]==2006 or e.args[0]==2013 : #MySQL server has gone away (((or)))    Exception 2013: Lost connection to MySQL server during query
             if tries>1:
                 self.logger.warn("DB Exception '%s'. Retry", str(e))
@@ -235,32 +236,45 @@ class db_base():
                 self.connect()
                 return
             else:
-                raise db_base_Exception("Database connection timeout Try Again", HTTP_Request_Timeout)
-        
+                raise db_base_Exception("Database connection timeout Try Again", httperrors.Request_Timeout)
+
         fk=e.args[1].find("foreign key constraint fails")
         if fk>=0:
             if command=="update":
-                raise db_base_Exception("tenant_id '{}' not found.".format(extra), HTTP_Not_Found)
+                raise db_base_Exception("tenant_id '{}' not found.".format(extra), httperrors.Not_Found)
             elif command=="delete":
-                raise db_base_Exception("Resource is not free. There are {} that prevent deleting it.".format(extra), HTTP_Conflict)
+                raise db_base_Exception("Resource is not free. There are {} that prevent deleting it.".format(extra), httperrors.Conflict)
         de = e.args[1].find("Duplicate entry")
         fk = e.args[1].find("for key")
         uk = e.args[1].find("Unknown column")
         wc = e.args[1].find("in 'where clause'")
         fl = e.args[1].find("in 'field list'")
         #print de, fk, uk, wc,fl
+        table_info = ' (table `{}`)'.format(table) if table else ''
         if de>=0:
             if fk>=0: #error 1062
-                raise db_base_Exception("Value {} already in use for {}".format(e.args[1][de+15:fk], e.args[1][fk+7:]), HTTP_Conflict)
+                raise db_base_Exception(
+                    "Value {} already in use for {}{}".format(
+                        e.args[1][de+15:fk], e.args[1][fk+7:], table_info),
+                    httperrors.Conflict)
         if uk>=0:
             if wc>=0:
-                raise db_base_Exception("Field {} can not be used for filtering".format(e.args[1][uk+14:wc]), HTTP_Bad_Request)
+                raise db_base_Exception(
+                    "Field {} can not be used for filtering{}".format(
+                        e.args[1][uk+14:wc], table_info),
+                    httperrors.Bad_Request)
             if fl>=0:
-                raise db_base_Exception("Field {} does not exist".format(e.args[1][uk+14:wc]), HTTP_Bad_Request)
-        raise db_base_Exception("Database internal Error {}: {}".format(e.args[0], e.args[1]), HTTP_Internal_Server_Error)
-    
+                raise db_base_Exception(
+                    "Field {} does not exist{}".format(
+                        e.args[1][uk+14:wc], table_info),
+                    httperrors.Bad_Request)
+        raise db_base_Exception(
+                "Database internal Error{} {}: {}".format(
+                    table_info, e.args[0], e.args[1]),
+                httperrors.Internal_Server_Error)
+
     def __str2db_format(self, data):
-        '''Convert string data to database format. 
+        '''Convert string data to database format.
         If data is None it returns the 'Null' text,
         otherwise it returns the text surrounded by quotes ensuring internal quotes are escaped.
         '''
@@ -270,10 +284,10 @@ class db_base():
             return json.dumps(data)
         else:
             return json.dumps(str(data))
-    
+
     def __tuple2db_format_set(self, data):
         """Compose the needed text for a SQL SET, parameter 'data' is a pair tuple (A,B),
-        and it returns the text 'A="B"', where A is a field of a table and B is the value 
+        and it returns the text 'A="B"', where A is a field of a table and B is the value
         If B is None it returns the 'A=Null' text, without surrounding Null by quotes
         If B is not None it returns the text "A='B'" or 'A="B"' where B is surrounded by quotes,
         and it ensures internal quotes of B are escaped.
@@ -287,10 +301,10 @@ class db_base():
         elif isinstance(data[1], dict):
             if "INCREMENT" in data[1]:
                 return "{A}={A}{N:+d}".format(A=data[0], N=data[1]["INCREMENT"])
-            raise db_base_Exception("Format error for UPDATE field")
+            raise db_base_Exception("Format error for UPDATE field: {!r}".format(data[0]))
         else:
             return str(data[0]) + '=' + json.dumps(str(data[1]))
-    
+
     def __create_where(self, data, use_or=None):
         """
         Compose the needed text for a SQL WHERE, parameter 'data' can be a dict or a list of dict. By default lists are
@@ -347,9 +361,9 @@ class db_base():
         '''remove single quotes ' of any string content of data dictionary'''
         for k,v in data.items():
             if type(v) == str:
-                if "'" in v: 
+                if "'" in v:
                     data[k] = data[k].replace("'","_")
-    
+
     def _update_rows(self, table, UPDATE, WHERE, modified_time=0):
         """ Update one or several rows of a table.
         :param UPDATE: dictionary with the changes. dict keys are database columns that will be set with the dict values
@@ -369,7 +383,7 @@ class db_base():
             values += ",modified_at={:f}".format(modified_time)
         cmd= "UPDATE " + table + " SET " + values + " WHERE " + self.__create_where(WHERE)
         self.logger.debug(cmd)
-        self.cur.execute(cmd) 
+        self.cur.execute(cmd)
         return self.cur.rowcount
 
     def _new_uuid(self, root_uuid=None, used_table=None, created_time=0):
@@ -398,7 +412,7 @@ class db_base():
 
     def _new_row_internal(self, table, INSERT, add_uuid=False, root_uuid=None, created_time=0, confidential_data=False):
         ''' Add one row into a table. It DOES NOT begin or end the transaction, so self.con.cursor must be created
-        Attribute 
+        Attribute
             INSERT: dictionary with the key:value to insert
             table: table where to insert
             add_uuid: if True, it will create an uuid key entry at INSERT if not provided
@@ -411,7 +425,7 @@ class db_base():
             #create uuid if not provided
             if 'uuid' not in INSERT:
                 uuid = INSERT['uuid'] = str(myUuid.uuid1()) # create_uuid
-            else: 
+            else:
                 uuid = str(INSERT['uuid'])
         else:
             uuid=None
@@ -429,7 +443,7 @@ class db_base():
             self.cur.execute(cmd)
         #insertion
         cmd= "INSERT INTO " + table +" SET " + \
-            ",".join(map(self.__tuple2db_format_set, INSERT.iteritems() )) 
+            ",".join(map(self.__tuple2db_format_set, INSERT.iteritems() ))
         if created_time:
             cmd += ",created_at=%f" % created_time
         if confidential_data:
@@ -448,16 +462,16 @@ class db_base():
         self.cur.execute(cmd)
         rows = self.cur.fetchall()
         return rows
-    
+
     def new_row(self, table, INSERT, add_uuid=False, created_time=0, confidential_data=False):
         ''' Add one row into a table.
-        Attribute 
+        Attribute
             INSERT: dictionary with the key: value to insert
             table: table where to insert
             tenant_id: only useful for logs. If provided, logs will use this tenant_id
             add_uuid: if True, it will create an uuid key entry at INSERT if not provided
         It checks presence of uuid and add one automatically otherwise
-        Return: (result, uuid) where result can be 0 if error, or 1 if ok
+        Return: uuid
         '''
         if table in self.tables_with_created_field and created_time==0:
             created_time=time.time()
@@ -467,9 +481,9 @@ class db_base():
                 with self.con:
                     self.cur = self.con.cursor()
                     return self._new_row_internal(table, INSERT, add_uuid, None, created_time, confidential_data)
-                    
+
             except (mdb.Error, AttributeError) as e:
-                self._format_error(e, tries)
+                self._format_error(e, tries, table=table)
             tries -= 1
 
     def update_rows(self, table, UPDATE, WHERE, modified_time=0):
@@ -493,10 +507,10 @@ class db_base():
             try:
                 with self.con:
                     self.cur = self.con.cursor()
-                    return self._update_rows(table, UPDATE, WHERE)
-                    
+                    return self._update_rows(
+                        table, UPDATE, WHERE, modified_time)
             except (mdb.Error, AttributeError) as e:
-                self._format_error(e, tries)
+                self._format_error(e, tries, table=table)
             tries -= 1
 
     def _delete_row_by_id_internal(self, table, uuid):
@@ -519,7 +533,8 @@ class db_base():
                     self.cur = self.con.cursor()
                     return self._delete_row_by_id_internal(table, uuid)
             except (mdb.Error, AttributeError) as e:
-                self._format_error(e, tries, "delete", "dependencies")
+                self._format_error(
+                    e, tries, "delete", "dependencies", table=table)
             tries -= 1
 
     def delete_row(self, **sql_dict):
@@ -567,9 +582,9 @@ class db_base():
                     rows = self.cur.fetchall()
                     return rows
             except (mdb.Error, AttributeError) as e:
-                self._format_error(e, tries)
+                self._format_error(e, tries, table=table)
             tries -= 1
-    
+
     def get_rows(self, **sql_dict):
         """ Obtain rows from a table.
         :param SELECT: list or tuple of fields to retrieve) (by default all)
@@ -581,7 +596,7 @@ class db_base():
                 keys can be suffixed by >,<,<>,>=,<= so that this is used to compare key and value instead of "="
                 The special keys "OR", "AND" with a dict value is used to create a nested WHERE
             If a list, each item will be a dictionary that will be concatenated with OR
-        :param LIMIT: limit the number of obtianied entries (Optional)
+        :param LIMIT: limit the number of obtained entries (Optional)
         :param ORDER_BY:  list or tuple of fields to order, add ' DESC' to each item if inverse order is required
         :return: a list with dictionaries at each row, raises exception upon error
         """
@@ -628,10 +643,10 @@ class db_base():
         Attribute:
             table: string of table name
             uuid_name: name or uuid. If not uuid format is found, it is considered a name
-            allow_severeral: if False return ERROR if more than one row are founded 
-            error_item_text: in case of error it identifies the 'item' name for a proper output text 
+            allow_severeral: if False return ERROR if more than one row are founded
+            error_item_text: in case of error it identifies the 'item' name for a proper output text
             'WHERE_OR': dict of key:values, translated to key=value OR ... (Optional)
-            'WHERE_AND_OR: str 'AND' or 'OR'(by default) mark the priority to 'WHERE AND (WHERE_OR)' or (WHERE) OR WHERE_OR' (Optional  
+            'WHERE_AND_OR: str 'AND' or 'OR'(by default) mark the priority to 'WHERE AND (WHERE_OR)' or (WHERE) OR WHERE_OR' (Optional
         Return: if allow_several==False, a dictionary with this row, or error if no item is found or more than one is found
                 if allow_several==True, a list of dictionaries with the row or rows, error if no item is found
         '''
@@ -656,16 +671,16 @@ class db_base():
                     self.cur.execute(cmd)
                     number = self.cur.rowcount
                     if number == 0:
-                        raise db_base_Exception("No {} found with {} '{}'".format(error_item_text, what, uuid_name), http_code=HTTP_Not_Found)
+                        raise db_base_Exception("No {} found with {} '{}'".format(error_item_text, what, uuid_name), http_code=httperrors.Not_Found)
                     elif number > 1 and not allow_serveral:
-                        raise db_base_Exception("More than one {} found with {} '{}'".format(error_item_text, what, uuid_name), http_code=HTTP_Conflict)
+                        raise db_base_Exception("More than one {} found with {} '{}'".format(error_item_text, what, uuid_name), http_code=httperrors.Conflict)
                     if allow_serveral:
                         rows = self.cur.fetchall()
                     else:
                         rows = self.cur.fetchone()
                     return rows
             except (mdb.Error, AttributeError) as e:
-                self._format_error(e, tries)
+                self._format_error(e, tries, table=table)
             tries -= 1
 
     def get_uuid(self, uuid):
@@ -684,7 +699,7 @@ class db_base():
 
     def get_uuid_from_name(self, table, name):
         '''Searchs in table the name and returns the uuid
-        ''' 
+        '''
         tries = 2
         while tries:
             try:
@@ -699,6 +714,6 @@ class db_base():
                         return self.cur.rowcount, "More than one VNF with name %s found in table %s" %(name, table)
                     return self.cur.rowcount, rows[0]["uuid"]
             except (mdb.Error, AttributeError) as e:
-                self._format_error(e, tries)
+                self._format_error(e, tries, table=table)
             tries -= 1