| import asyncio |
| import os |
| from collections import defaultdict |
| from functools import partial |
| from pathlib import Path |
| import base64 |
| from pyasn1.type import univ, char |
| from pyasn1.codec.der.encoder import encode |
| |
| |
async def execute_process(*cmd, log=None, loop=None):
    '''
    Wrapper around asyncio.create_subprocess_exec.

    Runs ``cmd`` to completion with stdin/stdout/stderr connected to pipes,
    and optionally logs the exit status and any captured output at debug
    level.

    :param cmd: Program to execute followed by its arguments.
    :param log: Optional logger-like object exposing ``debug``.
    :param loop: Ignored; accepted only for backward compatibility. The
        subprocess is bound to the running event loop (``loop=`` was removed
        from ``asyncio.create_subprocess_exec`` in Python 3.10).
    :return: True if the process exited with status 0, False otherwise.

    '''
    p = await asyncio.create_subprocess_exec(
        *cmd,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE)
    stdout, stderr = await p.communicate()
    if log:
        log.debug("Exec %s -> %d", cmd, p.returncode)
        # The output is logged purely for diagnostics; undecodable bytes must
        # not turn a debug message into a UnicodeDecodeError.
        if stdout:
            log.debug(stdout.decode('utf-8', errors='replace'))
        if stderr:
            log.debug(stderr.decode('utf-8', errors='replace'))
    return p.returncode == 0
| |
| |
def _read_ssh_key():
    '''
    Blocking helper behind read_ssh_key, intended to be run in an executor.

    Looks in ``$JUJU_DATA`` when that variable is set, otherwise in
    ``~/.local/share/juju``, opens ``ssh/juju_id_rsa.pub`` there and returns
    its first line with surrounding whitespace removed.

    :raises FileNotFoundError: (or another OSError) if the key file can't be
        opened.
    :raises IndexError: if the key file is empty.

    '''
    fallback_dir = Path(Path.home(), ".local", "share", "juju")
    data_dir = os.environ.get("JUJU_DATA", fallback_dir)
    pub_key_file = Path(data_dir) / 'ssh' / 'juju_id_rsa.pub'
    with pub_key_file.open('r') as stream:
        lines = stream.readlines()
    first_line = lines[0]
    return first_line.strip()
| |
| |
async def read_ssh_key(loop=None):
    '''
    Attempt to read the local juju admin's public ssh key, so that it
    can be passed on to a model.

    The file read is blocking, so it is handed to the event loop's default
    executor via ``_read_ssh_key``.

    :param loop: Event loop whose default executor performs the read.
        Optional (the body already falls back to the current event loop when
        it is falsy); existing positional callers are unaffected.
    :return: The key as returned by ``_read_ssh_key``.

    '''
    loop = loop or asyncio.get_event_loop()
    return await loop.run_in_executor(None, _read_ssh_key)
| |
| |
class IdQueue:
    """
    Wrapper around asyncio.Queue that maintains a separate queue for each ID.

    A queue is created lazily the first time an ID is used. It is discarded
    as soon as :meth:`get` takes a value from it, so any further values
    already queued for that ID are dropped.
    """
    def __init__(self, maxsize=0, *, loop=None):
        # ``loop`` is accepted for backward compatibility but not forwarded:
        # asyncio.Queue rejects it since Python 3.10. Queues are only created
        # inside the coroutine methods below, so they bind to the running loop.
        self._queues = defaultdict(partial(asyncio.Queue, maxsize))

    async def get(self, id):
        """Wait for the next value for ``id``, then discard that ID's queue.

        If the value is an Exception instance it is raised instead of
        returned.
        """
        value = await self._queues[id].get()
        del self._queues[id]
        if isinstance(value, Exception):
            raise value
        return value

    async def put(self, id, value):
        """Put ``value`` on the queue for ``id``, creating it if needed."""
        await self._queues[id].put(value)

    async def put_all(self, value):
        """Put ``value`` on every queue that currently exists."""
        # Iterate over a snapshot: ``put`` may suspend on a full queue, and a
        # concurrent ``get`` deletes entries from the dict in the meantime,
        # which would raise "dictionary changed size during iteration".
        for queue in list(self._queues.values()):
            await queue.put(value)
| |
| |
async def block_until(*conditions, timeout=None, wait_period=0.5, loop=None):
    """Return only after all conditions are true.

    Every callable in ``conditions`` is polled, sleeping ``wait_period``
    seconds between rounds, until all of them return a truthy value.

    :param conditions: Zero-argument callables to poll.
    :param timeout: Maximum number of seconds to wait, or None for no limit.
    :param wait_period: Seconds to sleep between polling rounds.
    :param loop: Ignored; accepted only for backward compatibility (the
        ``loop`` argument of ``asyncio.sleep``/``asyncio.wait_for`` was
        removed in Python 3.10).
    :raises asyncio.TimeoutError: if ``timeout`` elapses first.
    """
    async def _block():
        while not all(c() for c in conditions):
            await asyncio.sleep(wait_period)
    await asyncio.wait_for(_block(), timeout)
| |
| |
async def run_with_interrupt(task, *events, loop=None):
    """
    Awaits a task while allowing it to be interrupted by one or more
    `asyncio.Event`s.

    If the task finishes before any of the events becomes set, the task's
    result is returned (or its exception raised). If an event becomes set
    first, the task is cancelled and ``None`` is returned.

    :param task: Task (or coroutine/awaitable) to run
    :param events: One or more `asyncio.Event`s which, if set, will interrupt
        `task` and cause it to be cancelled.
    :param loop: Optional event loop used to schedule ``task`` and the event
        waiters; defaults to the current event loop.
    """
    loop = loop or asyncio.get_event_loop()
    task = asyncio.ensure_future(task, loop=loop)
    event_tasks = [loop.create_task(event.wait()) for event in events]
    # No ``loop=`` here: asyncio.wait dropped that argument in Python 3.10,
    # and the futures themselves already carry their loop.
    done, pending = await asyncio.wait([task] + event_tasks,
                                       return_when=asyncio.FIRST_COMPLETED)
    for f in pending:
        f.cancel()  # cancel unfinished tasks
    for f in done:
        # Retrieve exceptions to prevent "exception was never retrieved"
        # warnings. exception() itself raises CancelledError on a cancelled
        # future, so those are skipped.
        if not f.cancelled():
            f.exception()
    if task in done:
        return task.result()  # may raise exception
    return None
| |
| |
class Addrs(univ.SequenceOf):
    """ASN.1 ``SEQUENCE OF PrintableString`` carrying controller addresses.

    Corresponds to the ``Addrs []string`` field of Juju's Go
    ``RegistrationInfo`` struct.
    """
    componentType = char.PrintableString()
| |
| |
class RegistrationInfo(univ.Sequence):
    """
    ASN.1 SEQUENCE mirroring Juju's Go ``RegistrationInfo`` struct::

        type RegistrationInfo struct {
            User string
            Addrs []string
            SecretKey []byte
            ControllerName string
        }

    No ``componentType`` schema is declared on this class; the components
    are supplied by position when the message is built (0: User,
    1: Addrs, 2: SecretKey, 3: ControllerName).
    """
| |
| |
def generate_user_controller_access_token(username, controller_endpoints, secret_key, controller_name):
    """Implement in python what is currently done in GO
    https://github.com/juju/juju/blob/a5ab92ec9b7f5da3678d9ac603fe52d45af24412/cmd/juju/user/utils.go#L16

    Builds a DER-encoded ASN.1 RegistrationInfo message. The message is
    padded with zero bytes to a multiple of 3 bytes, so that the base64
    output contains no ``=`` characters, and is then base64url-encoded.

    :param username: name of the user to register
    :param controller_endpoints: juju controller endpoints list in the format <ip>:<port>
    :param secret_key: base64 encoded string of the secret-key generated by juju
    :param controller_name: name of the controller to register to.
    :return: the registration token, as base64url-encoded ``bytes``.
    """
    # The secret key arrives as a base64 encoded string, see:
    # https://websockets.readthedocs.io/en/stable/_modules/websockets/protocol.html#WebSocketCommonProtocol.recv
    # Decode it so the raw key bytes go into the ASN.1 OCTET STRING.
    secret_key = base64.b64decode(secret_key)
    addr = Addrs()
    for endpoint in controller_endpoints:
        addr.append(endpoint)

    # NOTE(review): username and controller_name are encoded as
    # PrintableString. Names with characters outside that alphabet
    # (e.g. '_' or '@') may be rejected by Go's encoding/asn1 when the token
    # is decoded -- confirm against Juju's naming rules.
    registration_info = RegistrationInfo()
    registration_info.setComponentByPosition(0, char.PrintableString(username))
    registration_info.setComponentByPosition(1, addr)
    registration_info.setComponentByPosition(2, univ.OctetString(secret_key))
    registration_info.setComponentByPosition(3, char.PrintableString(controller_name))
    registration_data = encode(registration_info)

    # Pad with zero bytes up to the next multiple of 3, so the base64 output
    # has no '=' padding. As in the Go implementation, nothing is appended
    # when the length is already a multiple of 3.
    registration_data += b"\0" * (-len(registration_data) % 3)
    return base64.urlsafe_b64encode(registration_data)