X-Git-Url: https://osm.etsi.org/gitweb/?p=osm%2FN2VC.git;a=blobdiff_plain;f=modules%2Flibjuju%2Fjuju%2Futils.py;h=1038ed1ee578c594a1b6161fe9cefc9821f0cee9;hp=3565fd630c6a3dc3839fe3f382bdf5e1aab0fdc2;hb=d4ec83bbe1d74a7432ea472dfe5b748d1611bde4;hpb=c3e6c2ec9a1fddfc8e9bd31509b366e633b6d99e diff --git a/modules/libjuju/juju/utils.py b/modules/libjuju/juju/utils.py index 3565fd6..1038ed1 100644 --- a/modules/libjuju/juju/utils.py +++ b/modules/libjuju/juju/utils.py @@ -3,6 +3,9 @@ import os from collections import defaultdict from functools import partial from pathlib import Path +import base64 +from pyasn1.type import univ, char +from pyasn1.codec.der.encoder import encode async def execute_process(*cmd, log=None, loop=None): @@ -82,32 +85,81 @@ async def block_until(*conditions, timeout=None, wait_period=0.5, loop=None): await asyncio.wait_for(_block(), timeout, loop=loop) -async def run_with_interrupt(task, event, loop=None): +async def run_with_interrupt(task, *events, loop=None): """ - Awaits a task while allowing it to be interrupted by an `asyncio.Event`. + Awaits a task while allowing it to be interrupted by one or more + `asyncio.Event`s. - If the task finishes without the event becoming set, the results of the - task will be returned. If the event becomes set, the task will be - cancelled ``None`` will be returned. + If the task finishes without the events becoming set, the results of the + task will be returned. If the event become set, the task will be cancelled + ``None`` will be returned. :param task: Task to run - :param event: An `asyncio.Event` which, if set, will interrupt `task` - and cause it to be cancelled. + :param events: One or more `asyncio.Event`s which, if set, will interrupt + `task` and cause it to be cancelled. :param loop: Optional event loop to use other than the default. """ loop = loop or asyncio.get_event_loop() - event_task = loop.create_task(event.wait()) - done, pending = await asyncio.wait([task, event_task], + task = asyncio.ensure_future(task, loop=loop) + event_tasks = [loop.create_task(event.wait()) for event in events] + done, pending = await asyncio.wait([task] + event_tasks, loop=loop, return_when=asyncio.FIRST_COMPLETED) for f in pending: - f.cancel() - exception = [f.exception() for f in done - if f is not event_task and f.exception()] - if exception: - raise exception[0] - result = [f.result() for f in done if f is not event_task] - if result: - return result[0] + f.cancel() # cancel unfinished tasks + for f in done: + f.exception() # prevent "exception was not retrieved" errors + if task in done: + return task.result() # may raise exception else: return None + + +class Addrs(univ.SequenceOf): + componentType = char.PrintableString() + + +class RegistrationInfo(univ.Sequence): + """ + ASN.1 representation of: + + type RegistrationInfo struct { + User string + + Addrs []string + + SecretKey []byte + + ControllerName string + } + """ + pass + + +def generate_user_controller_access_token(username, controller_endpoints, secret_key, controller_name): + """" Implement in python what is currently done in GO + https://github.com/juju/juju/blob/a5ab92ec9b7f5da3678d9ac603fe52d45af24412/cmd/juju/user/utils.go#L16 + + :param username: name of the user to register + :param controller_endpoints: juju controller endpoints list in the format : + :param secret_key: base64 encoded string of the secret-key generated by juju + :param controller_name: name of the controller to register to. + """ + + # Secret key is returned as base64 encoded string in: + # https://websockets.readthedocs.io/en/stable/_modules/websockets/protocol.html#WebSocketCommonProtocol.recv + # Deconding it before marshalling into the ASN.1 message + secret_key = base64.b64decode(secret_key) + addr = Addrs() + for endpoint in controller_endpoints: + addr.append(endpoint) + + registration_string = RegistrationInfo() + registration_string.setComponentByPosition(0, char.PrintableString(username)) + registration_string.setComponentByPosition(1, addr) + registration_string.setComponentByPosition(2, univ.OctetString(secret_key)) + registration_string.setComponentByPosition(3, char.PrintableString(controller_name)) + registration_string = encode(registration_string) + remainder = len(registration_string) % 3 + registration_string += b"\0" * (3 - remainder) + return base64.urlsafe_b64encode(registration_string)