Refactor connection task management to avoid cancels (#117)
[osm/N2VC.git] / juju / utils.py
index 9f5d63d..1d1b24e 100644 (file)
@@ -59,7 +59,40 @@ class IdQueue:
     async def get(self, id):
         value = await self._queues[id].get()
         del self._queues[id]
+        if isinstance(value, Exception):
+            raise value
         return value
 
     async def put(self, id, value):
         await self._queues[id].put(value)
+
+    async def put_all(self, value):
+        for queue in self._queues.values():
+            await queue.put(value)
+
+
+async def run_with_interrupt(task, event, loop=None):
+    """
+    Awaits a task while allowing it to be interrupted by an `asyncio.Event`.
+
+    If the task finishes without the event becoming set, the results of the
+    task will be returned.  If the event becomes set, the task will be
+    cancelled ``None`` will be returned.
+
+    :param task: Task to run
+    :param event: An `asyncio.Event` which, if set, will interrupt `task`
+        and cause it to be cancelled.
+    :param loop: Optional event loop to use other than the default.
+    """
+    loop = loop or asyncio.get_event_loop()
+    event_task = loop.create_task(event.wait())
+    done, pending = await asyncio.wait([task, event_task],
+                                       loop=loop,
+                                       return_when=asyncio.FIRST_COMPLETED)
+    for f in pending:
+        f.cancel()
+    result = [f.result() for f in done if f is not event_task]
+    if result:
+        return result[0]
+    else:
+        return None