update from RIFT as of 696b75d2fe9fb046261b08c616f1bcf6c0b54a9b second try
[osm/SO.git] / common / python / rift / mano / dts / rpc / core.py
index dfa08bb..72016f1 100644 (file)
@@ -36,8 +36,8 @@ from ..core import DtsHandler
 class AbstractRpcHandler(DtsHandler):
     """Base class to simplify RPC implementation
     """
-    def __init__(self, log, dts, loop):
-        super().__init__(log, dts, loop)
+    def __init__(self, log, dts, loop, project=None):
+        super().__init__(log, dts, loop, project)
 
         if not asyncio.iscoroutinefunction(self.callback):
             raise ValueError('%s has to be a coroutine' % (self.callback))
@@ -61,6 +61,9 @@ class AbstractRpcHandler(DtsHandler):
     def on_prepare(self, xact_info, action, ks_path, msg):
         assert action == rwdts.QueryAction.RPC
 
+        if self.project and not self.project.rpc_check(msg, xact_info=xact_info):
+            return
+
         try:
             rpc_op = yield from self.callback(ks_path, msg)
             xact_info.respond_xpath(
@@ -76,6 +79,11 @@ class AbstractRpcHandler(DtsHandler):
 
     @asyncio.coroutine
     def register(self):
+        if self.reg:
+            self._log.warning("RPC already registered for project {}".
+                              format(self._project.name))
+            return
+
         reg_event = asyncio.Event(loop=self.loop)
 
         @asyncio.coroutine
@@ -94,6 +102,10 @@ class AbstractRpcHandler(DtsHandler):
 
         yield from reg_event.wait()
 
+    def deregister(self):
+        self.reg.deregister()
+        self.reg = None
+
     @abc.abstractmethod
     @asyncio.coroutine
     def callback(self, ks_path, msg):