Refactor connection task management to avoid cancels (#117)
[osm/N2VC.git] / juju / utils.py
index f4db66e..1d1b24e 100644 (file)
@@ -69,3 +69,30 @@ class IdQueue:
     async def put_all(self, value):
         for queue in self._queues.values():
             await queue.put(value)
+
+
+async def run_with_interrupt(task, event, loop=None):
+    """
+    Awaits a task while allowing it to be interrupted by an `asyncio.Event`.
+
+    If the task finishes without the event becoming set, the results of the
+    task will be returned.  If the event becomes set, the task will be
+    cancelled ``None`` will be returned.
+
+    :param task: Task to run
+    :param event: An `asyncio.Event` which, if set, will interrupt `task`
+        and cause it to be cancelled.
+    :param loop: Optional event loop to use other than the default.
+    """
+    loop = loop or asyncio.get_event_loop()
+    event_task = loop.create_task(event.wait())
+    done, pending = await asyncio.wait([task, event_task],
+                                       loop=loop,
+                                       return_when=asyncio.FIRST_COMPLETED)
+    for f in pending:
+        f.cancel()
+    result = [f.result() for f in done if f is not event_task]
+    if result:
+        return result[0]
+    else:
+        return None