Revert "Remove vendored libjuju"
[osm/N2VC.git] / modules / libjuju / juju / utils.py
diff --git a/modules/libjuju/juju/utils.py b/modules/libjuju/juju/utils.py
new file mode 100644 (file)
index 0000000..1038ed1
--- /dev/null
@@ -0,0 +1,165 @@
+import asyncio
+import os
+from collections import defaultdict
+from functools import partial
+from pathlib import Path
+import base64
+from pyasn1.type import univ, char
+from pyasn1.codec.der.encoder import encode
+
+
+async def execute_process(*cmd, log=None, loop=None):
+    '''
+    Wrapper around asyncio.create_subprocess_exec.
+
+    '''
+    p = await asyncio.create_subprocess_exec(
+        *cmd,
+        stdin=asyncio.subprocess.PIPE,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.PIPE,
+        loop=loop)
+    stdout, stderr = await p.communicate()
+    if log:
+        log.debug("Exec %s -> %d", cmd, p.returncode)
+        if stdout:
+            log.debug(stdout.decode('utf-8'))
+        if stderr:
+            log.debug(stderr.decode('utf-8'))
+    return p.returncode == 0
+
+
+def _read_ssh_key():
+    '''
+    Inner function for read_ssh_key, suitable for passing to our
+    Executor.
+
+    '''
+    default_data_dir = Path(Path.home(), ".local", "share", "juju")
+    juju_data = os.environ.get("JUJU_DATA", default_data_dir)
+    ssh_key_path = Path(juju_data, 'ssh', 'juju_id_rsa.pub')
+    with ssh_key_path.open('r') as ssh_key_file:
+        ssh_key = ssh_key_file.readlines()[0].strip()
+    return ssh_key
+
+
+async def read_ssh_key(loop):
+    '''
+    Attempt to read the local juju admin's public ssh key, so that it
+    can be passed on to a model.
+
+    '''
+    loop = loop or asyncio.get_event_loop()
+    return await loop.run_in_executor(None, _read_ssh_key)
+
+
+class IdQueue:
+    """
+    Wrapper around asyncio.Queue that maintains a separate queue for each ID.
+    """
+    def __init__(self, maxsize=0, *, loop=None):
+        self._queues = defaultdict(partial(asyncio.Queue, maxsize, loop=loop))
+
+    async def get(self, id):
+        value = await self._queues[id].get()
+        del self._queues[id]
+        if isinstance(value, Exception):
+            raise value
+        return value
+
+    async def put(self, id, value):
+        await self._queues[id].put(value)
+
+    async def put_all(self, value):
+        for queue in self._queues.values():
+            await queue.put(value)
+
+
+async def block_until(*conditions, timeout=None, wait_period=0.5, loop=None):
+    """Return only after all conditions are true.
+
+    """
+    async def _block():
+        while not all(c() for c in conditions):
+            await asyncio.sleep(wait_period, loop=loop)
+    await asyncio.wait_for(_block(), timeout, loop=loop)
+
+
+async def run_with_interrupt(task, *events, loop=None):
+    """
+    Awaits a task while allowing it to be interrupted by one or more
+    `asyncio.Event`s.
+
+    If the task finishes without the events becoming set, the results of the
+    task will be returned.  If the event become set, the task will be cancelled
+    ``None`` will be returned.
+
+    :param task: Task to run
+    :param events: One or more `asyncio.Event`s which, if set, will interrupt
+        `task` and cause it to be cancelled.
+    :param loop: Optional event loop to use other than the default.
+    """
+    loop = loop or asyncio.get_event_loop()
+    task = asyncio.ensure_future(task, loop=loop)
+    event_tasks = [loop.create_task(event.wait()) for event in events]
+    done, pending = await asyncio.wait([task] + event_tasks,
+                                       loop=loop,
+                                       return_when=asyncio.FIRST_COMPLETED)
+    for f in pending:
+        f.cancel()  # cancel unfinished tasks
+    for f in done:
+        f.exception()  # prevent "exception was not retrieved" errors
+    if task in done:
+        return task.result()  # may raise exception
+    else:
+        return None
+
+
+class Addrs(univ.SequenceOf):
+    componentType = char.PrintableString()
+
+
+class RegistrationInfo(univ.Sequence):
+    """
+    ASN.1 representation of:
+
+    type RegistrationInfo struct {
+    User string
+
+        Addrs []string
+
+        SecretKey []byte
+
+        ControllerName string
+    }
+    """
+    pass
+
+
+def generate_user_controller_access_token(username, controller_endpoints, secret_key, controller_name):
+    """" Implement in python what is currently done in GO
+    https://github.com/juju/juju/blob/a5ab92ec9b7f5da3678d9ac603fe52d45af24412/cmd/juju/user/utils.go#L16
+
+    :param username: name of the user to register
+    :param controller_endpoints: juju controller endpoints list in the format <ip>:<port>
+    :param secret_key: base64 encoded string of the secret-key generated by juju
+    :param controller_name: name of the controller to register to.
+    """
+
+    # Secret key is returned as base64 encoded string in:
+    # https://websockets.readthedocs.io/en/stable/_modules/websockets/protocol.html#WebSocketCommonProtocol.recv
+    # Deconding it before marshalling into the ASN.1 message
+    secret_key = base64.b64decode(secret_key)
+    addr = Addrs()
+    for endpoint in controller_endpoints:
+        addr.append(endpoint)
+
+    registration_string = RegistrationInfo()
+    registration_string.setComponentByPosition(0, char.PrintableString(username))
+    registration_string.setComponentByPosition(1, addr)
+    registration_string.setComponentByPosition(2, univ.OctetString(secret_key))
+    registration_string.setComponentByPosition(3, char.PrintableString(controller_name))
+    registration_string = encode(registration_string)
+    remainder = len(registration_string) % 3
+    registration_string += b"\0" * (3 - remainder)
+    return base64.urlsafe_b64encode(registration_string)