Add region to nets and allow ovelapping vlan at each region
[osm/openvim.git] / osm_openvim / vim_db.py
index c34160d..f69c305 100644 (file)
@@ -56,8 +56,7 @@ class vim_db():
         ''' 
         #initialization
         self.net_vlan_range = vlan_range
-        self.net_vlan_usedlist = None
-        self.net_vlan_lastused = self.net_vlan_range[0] -1
+        self.vlan_config = {}
         self.debug=debug
         if logger_name:
             self.logger_name = logger_name
@@ -176,43 +175,49 @@ class vim_db():
         else:
             return json.dumps(out)
     
-    def __get_used_net_vlan(self):
+    def __get_used_net_vlan(self, region=None):
         #get used from database if needed
+        vlan_region = self.vlan_config[region]
         try:
-            cmd = "SELECT vlan FROM nets WHERE vlan>='%s' ORDER BY vlan LIMIT 25" % self.net_vlan_lastused
+            cmd = "SELECT vlan FROM nets WHERE vlan>='{}' and region{} ORDER BY vlan LIMIT 25".format(
+                vlan_region["lastused"], "='"+region+"'" if region else " is NULL")
             with self.con:
                 self.cur = self.con.cursor()
                 self.logger.debug(cmd)
                 self.cur.execute(cmd)
                 vlan_tuple = self.cur.fetchall()
-                #convert a tuple of tuples in a list of numbers
-                self.net_vlan_usedlist = []
+                # convert a tuple of tuples in a list of numbers
+                vlan_region["usedlist"] = []
                 for k in vlan_tuple:
-                    self.net_vlan_usedlist.append(k[0])
-            return 0
+                    vlan_region["usedlist"].append(k[0])
         except (mdb.Error, AttributeError) as e:
             return self.format_error(e, "get_free_net_vlan", cmd)
     
-    def get_free_net_vlan(self):
+    def get_free_net_vlan(self, region=None):
         '''obtain a vlan not used in any net'''
-        
+        if region not in self.vlan_config:
+            self.vlan_config[region] = {
+                "usedlist": None,
+                "lastused": self.net_vlan_range[0] - 1
+            }
+        vlan_region = self.vlan_config[region]
+
         while True:
-            self.logger.debug("net_vlan_lastused:%d  net_vlan_range:%d-%d  net_vlan_usedlist:%s", 
-                            self.net_vlan_lastused, self.net_vlan_range[0], self.net_vlan_range[1], str(self.net_vlan_usedlist))
-            self.net_vlan_lastused += 1
-            if self.net_vlan_lastused ==  self.net_vlan_range[1]:
-                #start from the begining
-                self.net_vlan_lastused =  self.net_vlan_range[0]
-                self.net_vlan_usedlist = None
-            if self.net_vlan_usedlist is None \
-            or (len(self.net_vlan_usedlist)>0 and self.net_vlan_lastused >= self.net_vlan_usedlist[-1] and len(self.net_vlan_usedlist)==25):
-                r = self.__get_used_net_vlan()
-                if r<0: return r
-                self.logger.debug("new net_vlan_usedlist %s", str(self.net_vlan_usedlist))
-            if self.net_vlan_lastused in self.net_vlan_usedlist:
+            self.logger.debug("get_free_net_vlan() region[{}]={}, net_vlan_range:{}-{} ".format(region, vlan_region,
+                            self.net_vlan_range[0], self.net_vlan_range[1]))
+            vlan_region["lastused"] += 1
+            if vlan_region["lastused"] ==  self.net_vlan_range[1]:
+                # start from the begining
+                vlan_region["lastused"] =  self.net_vlan_range[0]
+                vlan_region["usedlist"] = None
+            if vlan_region["usedlist"] is None or \
+                    (len(vlan_region["usedlist"])==25 and vlan_region["lastused"] >= vlan_region["usedlist"][-1]):
+                self.__get_used_net_vlan(region)
+                self.logger.debug("new net_vlan_usedlist %s", str(vlan_region["usedlist"]))
+            if vlan_region["lastused"] in vlan_region["usedlist"]:
                 continue
             else:
-                return self.net_vlan_lastused
+                return vlan_region["lastused"]
                 
     def get_table(self, **sql_dict):
         ''' Obtain rows from a table.