from collections import defaultdict
from functools import partial
from pathlib import Path
+import base64
+from pyasn1.type import univ, char
+from pyasn1.codec.der.encoder import encode
async def execute_process(*cmd, log=None, loop=None):
await asyncio.wait_for(_block(), timeout, loop=loop)
-async def run_with_interrupt(task, event, loop=None):
+async def run_with_interrupt(task, *events, loop=None):
"""
- Awaits a task while allowing it to be interrupted by an `asyncio.Event`.
+ Awaits a task while allowing it to be interrupted by one or more
+ `asyncio.Event`s.
- If the task finishes without the event becoming set, the results of the
- task will be returned. If the event becomes set, the task will be
- cancelled ``None`` will be returned.
+ If the task finishes without the events becoming set, the results of the
+ task will be returned. If the event become set, the task will be cancelled
+ ``None`` will be returned.
:param task: Task to run
- :param event: An `asyncio.Event` which, if set, will interrupt `task`
- and cause it to be cancelled.
+ :param events: One or more `asyncio.Event`s which, if set, will interrupt
+ `task` and cause it to be cancelled.
:param loop: Optional event loop to use other than the default.
"""
loop = loop or asyncio.get_event_loop()
- event_task = loop.create_task(event.wait())
- done, pending = await asyncio.wait([task, event_task],
+ task = asyncio.ensure_future(task, loop=loop)
+ event_tasks = [loop.create_task(event.wait()) for event in events]
+ done, pending = await asyncio.wait([task] + event_tasks,
loop=loop,
return_when=asyncio.FIRST_COMPLETED)
for f in pending:
- f.cancel()
- exception = [f.exception() for f in done
- if f is not event_task and f.exception()]
- if exception:
- raise exception[0]
- result = [f.result() for f in done if f is not event_task]
- if result:
- return result[0]
+ f.cancel() # cancel unfinished tasks
+ for f in done:
+ f.exception() # prevent "exception was not retrieved" errors
+ if task in done:
+ return task.result() # may raise exception
else:
return None
+
+
+class Addrs(univ.SequenceOf):
+ componentType = char.PrintableString()
+
+
+class RegistrationInfo(univ.Sequence):
+ """
+ ASN.1 representation of:
+
+ type RegistrationInfo struct {
+ User string
+
+ Addrs []string
+
+ SecretKey []byte
+
+ ControllerName string
+ }
+ """
+ pass
+
+
+def generate_user_controller_access_token(username, controller_endpoints, secret_key, controller_name):
+ """" Implement in python what is currently done in GO
+ https://github.com/juju/juju/blob/a5ab92ec9b7f5da3678d9ac603fe52d45af24412/cmd/juju/user/utils.go#L16
+
+ :param username: name of the user to register
+ :param controller_endpoints: juju controller endpoints list in the format <ip>:<port>
+ :param secret_key: base64 encoded string of the secret-key generated by juju
+ :param controller_name: name of the controller to register to.
+ """
+
+ # Secret key is returned as base64 encoded string in:
+ # https://websockets.readthedocs.io/en/stable/_modules/websockets/protocol.html#WebSocketCommonProtocol.recv
+ # Deconding it before marshalling into the ASN.1 message
+ secret_key = base64.b64decode(secret_key)
+ addr = Addrs()
+ for endpoint in controller_endpoints:
+ addr.append(endpoint)
+
+ registration_string = RegistrationInfo()
+ registration_string.setComponentByPosition(0, char.PrintableString(username))
+ registration_string.setComponentByPosition(1, addr)
+ registration_string.setComponentByPosition(2, univ.OctetString(secret_key))
+ registration_string.setComponentByPosition(3, char.PrintableString(controller_name))
+ registration_string = encode(registration_string)
+ remainder = len(registration_string) % 3
+ registration_string += b"\0" * (3 - remainder)
+ return base64.urlsafe_b64encode(registration_string)