X-Git-Url: https://osm.etsi.org/gitweb/?a=blobdiff_plain;f=juju%2Fmodel.py;h=bd8709a3bfa057a68319d0a5d732b8c54a55d8d7;hb=c50c361a8b9a3bbf1a33f5659e492b481f065cd2;hp=04fb2d47aea5f17e42af9102e4cf04e1b4182c9a;hpb=e6b7b9d9d091c1ef32236370183ce2ff22d01956;p=osm%2FN2VC.git diff --git a/juju/model.py b/juju/model.py index 04fb2d4..bd8709a 100644 --- a/juju/model.py +++ b/juju/model.py @@ -1,5 +1,7 @@ import asyncio +import base64 import collections +import hashlib import json import logging import os @@ -12,13 +14,15 @@ from concurrent.futures import CancelledError from functools import partial from pathlib import Path +import websockets import yaml import theblues.charmstore import theblues.errors +from . import tag, utils from .client import client -from .client import watcher from .client import connection +from .client.client import ConfigValue from .constraints import parse as parse_constraints, normalize_key from .delta import get_entity_delta from .delta import get_entity_class @@ -74,6 +78,9 @@ class _Observer(object): class ModelObserver(object): + """ + Base class for creating observers that react to changes in a model. + """ async def __call__(self, delta, old, new, model): handler_name = 'on_{}_{}'.format(delta.entity, delta.type) method = getattr(self, handler_name, self.on_change) @@ -82,6 +89,8 @@ class ModelObserver(object): async def on_change(self, delta, old, new, model): """Generic model-change handler. + This should be overridden in a subclass. + :param delta: :class:`juju.client.overrides.Delta` :param old: :class:`juju.model.ModelEntity` :param new: :class:`juju.model.ModelEntity` @@ -230,7 +239,14 @@ class ModelEntity(object): model. """ - return self.safe_data[name] + try: + return self.safe_data[name] + except KeyError: + name = name.replace('_', '-') + if name in self.safe_data: + return self.safe_data[name] + else: + raise def __bool__(self): return bool(self.data) @@ -365,22 +381,39 @@ class ModelEntity(object): class Model(object): - def __init__(self, loop=None): + """ + The main API for interacting with a Juju model. + """ + def __init__(self, loop=None, + max_frame_size=connection.Connection.DEFAULT_FRAME_SIZE): """Instantiate a new connected Model. :param loop: an asyncio event loop + :param max_frame_size: See + `juju.client.connection.Connection.MAX_FRAME_SIZE` """ self.loop = loop or asyncio.get_event_loop() + self.max_frame_size = max_frame_size self.connection = None self.observers = weakref.WeakValueDictionary() self.state = ModelState(self) self.info = None - self._watcher_task = None - self._watch_shutdown = asyncio.Event(loop=self.loop) + self._watch_stopping = asyncio.Event(loop=self.loop) + self._watch_stopped = asyncio.Event(loop=self.loop) self._watch_received = asyncio.Event(loop=self.loop) self._charmstore = CharmStore(self.loop) + async def __aenter__(self): + await self.connect_current() + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.disconnect() + + if exc_type is not None: + return False + async def connect(self, *args, **kw): """Connect to an arbitrary Juju model. @@ -389,6 +422,8 @@ class Model(object): """ if 'loop' not in kw: kw['loop'] = self.loop + if 'max_frame_size' not in kw: + kw['max_frame_size'] = self.max_frame_size self.connection = await connection.Connection.connect(*args, **kw) await self._after_connect() @@ -397,7 +432,7 @@ class Model(object): """ self.connection = await connection.Connection.connect_current( - self.loop) + self.loop, max_frame_size=self.max_frame_size) await self._after_connect() async def connect_model(self, model_name): @@ -406,8 +441,8 @@ class Model(object): :param model_name: Format [controller:][user/]model """ - self.connection = await connection.Connection.connect_model(model_name, - self.loop) + self.connection = await connection.Connection.connect_model( + model_name, self.loop, self.max_frame_size) await self._after_connect() async def _after_connect(self): @@ -422,9 +457,10 @@ class Model(object): """Shut down the watcher task and close websockets. """ - self._stop_watching() if self.connection and self.connection.is_open: - await self._watch_shutdown.wait() + log.debug('Stopping watcher task') + self._watch_stopping.set() + await self._watch_stopped.wait() log.debug('Closing model connection') await self.connection.close() self.connection = None @@ -516,6 +552,8 @@ class Model(object): """ async def _block(): while not all(c() for c in conditions): + if not (self.connection and self.connection.is_open): + raise websockets.ConnectionClosed(1006, 'no reason') await asyncio.sleep(wait_period, loop=self.loop) await asyncio.wait_for(_block(), timeout, loop=self.loop) @@ -556,8 +594,7 @@ class Model(object): explicit call to this method. """ - facade = client.ClientFacade() - facade.connect(self.connection) + facade = client.ClientFacade.from_connection(self.connection) self.info = await facade.ModelInfo() log.debug('Got ModelInfo: %s', vars(self.info)) @@ -610,43 +647,64 @@ class Model(object): See :meth:`add_observer` to register an onchange callback. """ - async def _start_watch(): - self._watch_shutdown.clear() + async def _all_watcher(): try: - allwatcher = watcher.AllWatcher() - self._watch_conn = await self.connection.clone() - allwatcher.connect(self._watch_conn) - while True: - results = await allwatcher.Next() + allwatcher = client.AllWatcherFacade.from_connection( + self.connection) + while not self._watch_stopping.is_set(): + try: + results = await utils.run_with_interrupt( + allwatcher.Next(), + self._watch_stopping, + self.loop) + except JujuAPIError as e: + if 'watcher was stopped' not in str(e): + raise + if self._watch_stopping.is_set(): + # this shouldn't ever actually happen, because + # the event should trigger before the controller + # has a chance to tell us the watcher is stopped + # but handle it gracefully, just in case + break + # controller stopped our watcher for some reason + # but we're not actually stopping, so just restart it + log.warning( + 'Watcher: watcher stopped, restarting') + del allwatcher.Id + continue + except websockets.ConnectionClosed: + monitor = self.connection.monitor + if monitor.status == monitor.ERROR: + # closed unexpectedly, try to reopen + log.warning( + 'Watcher: connection closed, reopening') + await self.connection.reconnect() + del allwatcher.Id + continue + else: + # closed on request, go ahead and shutdown + break + if self._watch_stopping.is_set(): + await allwatcher.Stop() + break for delta in results.deltas: delta = get_entity_delta(delta) old_obj, new_obj = self.state.apply_delta(delta) - # XXX: Might not want to shield at this level - # We are shielding because when the watcher is - # canceled (on disconnect()), we don't want all of - # its children (every observer callback) to be - # canceled with it. So we shield them. But this means - # they can *never* be canceled. - await asyncio.shield( - self._notify_observers(delta, old_obj, new_obj), - loop=self.loop) + await self._notify_observers(delta, old_obj, new_obj) self._watch_received.set() except CancelledError: - log.debug('Closing watcher connection') - await self._watch_conn.close() - self._watch_shutdown.set() - self._watch_conn = None + pass + except Exception: + log.exception('Error in watcher') + raise + finally: + self._watch_stopped.set() log.debug('Starting watcher task') - self._watcher_task = self.loop.create_task(_start_watch()) - - def _stop_watching(self): - """Stop the asynchronous watch against this model. - - """ - log.debug('Stopping watcher task') - if self._watcher_task: - self._watcher_task.cancel() + self._watch_received.clear() + self._watch_stopping.clear() + self._watch_stopped.clear() + self.loop.create_task(_all_watcher()) async def _notify_observers(self, delta, old_obj, new_obj): """Call observing callbacks, notifying them of a change in model state @@ -744,14 +802,34 @@ class Model(object): 'zone=us-east-1a' - starts a machine in zone us-east-1s on AWS 'maas2.name' - acquire machine maas2.name on MAAS - :param dict constraints: Machine constraints + :param dict constraints: Machine constraints, which can contain the + the following keys:: + + arch : str + container : str + cores : int + cpu_power : int + instance_type : str + mem : int + root_disk : int + spaces : list(str) + tags : list(str) + virt_type : str + Example:: constraints={ 'mem': 256 * MB, + 'tags': ['virtual'], } - :param list disks: List of disk constraint dictionaries + :param list disks: List of disk constraint dictionaries, which can + contain the following keys:: + + count : int + pool : str + size : int + Example:: disks=[{ @@ -787,12 +865,11 @@ class Model(object): params.series = series # Submit the request. - client_facade = client.ClientFacade() - client_facade.connect(self.connection) + client_facade = client.ClientFacade.from_connection(self.connection) results = await client_facade.AddMachines([params]) error = results.machines[0].error if error: - raise ValueError("Error adding machine: %s", error.message) + raise ValueError("Error adding machine: %s" % error.message) machine_id = results.machines[0].machine log.debug('Added new machine %s', machine_id) return await self._wait_for_new('machine', machine_id) @@ -804,8 +881,7 @@ class Model(object): :param str relation2: '[:]' """ - app_facade = client.ApplicationFacade() - app_facade.connect(self.connection) + app_facade = client.ApplicationFacade.from_connection(self.connection) log.debug( 'Adding relation %s <-> %s', relation1, relation2) @@ -841,13 +917,15 @@ class Model(object): """ raise NotImplementedError() - def add_ssh_key(self, key): + async def add_ssh_key(self, user, key): """Add a public SSH key to this model. + :param str user: The username of the user :param str key: The public ssh key """ - raise NotImplementedError() + key_facade = client.KeyManagerFacade.from_connection(self.connection) + return await key_facade.AddKeys([key], user) add_ssh_keys = add_ssh_key def add_subnet(self, cidr_or_id, space, *zones): @@ -963,7 +1041,7 @@ class Model(object): :param dict bind: : pairs :param dict budget: : pairs :param str channel: Charm store channel from which to retrieve - the charm or bundle, e.g. 'development' + the charm or bundle, e.g. 'edge' :param dict config: Charm configuration dictionary :param constraints: Service constraints :type constraints: :class:`juju.Constraints` @@ -985,9 +1063,7 @@ class Model(object): TODO:: - - application_name is required; fill this in automatically if not - provided by caller - - series is required; how do we pick a default? + - support local resources """ if storage: @@ -1001,13 +1077,12 @@ class Model(object): os.path.isdir(entity_url) ) if is_local: - entity_id = entity_url + entity_id = entity_url.replace('local:', '') else: - entity = await self.charmstore.entity(entity_url) + entity = await self.charmstore.entity(entity_url, channel=channel) entity_id = entity['Id'] - client_facade = client.ClientFacade() - client_facade.connect(self.connection) + client_facade = client.ClientFacade.from_connection(self.connection) is_bundle = ((is_local and (Path(entity_id) / 'bundle.yaml').exists()) or @@ -1036,9 +1111,12 @@ class Model(object): application_name = entity['Meta']['charm-metadata']['Name'] if not series: series = self._get_series(entity_url, entity) - if not channel: - channel = 'stable' await client_facade.AddCharm(channel, entity_id) + # XXX: we're dropping local resources here, but we don't + # actually support them yet anyway + resources = await self._add_store_resources(application_name, + entity_id, + entity) else: # We have a local charm dir that needs to be uploaded charm_dir = os.path.abspath( @@ -1061,9 +1139,40 @@ class Model(object): storage=storage, channel=channel, num_units=num_units, - placement=parse_placement(to), + placement=parse_placement(to) ) + async def _add_store_resources(self, application, entity_url, entity=None): + if not entity: + # avoid extra charm store call if one was already made + entity = await self.charmstore.entity(entity_url) + resources = [ + { + 'description': resource['Description'], + 'fingerprint': resource['Fingerprint'], + 'name': resource['Name'], + 'path': resource['Path'], + 'revision': resource['Revision'], + 'size': resource['Size'], + 'type_': resource['Type'], + 'origin': 'store', + } for resource in entity['Meta']['resources'] + ] + + if not resources: + return None + + resources_facade = client.ResourcesFacade.from_connection( + self.connection) + response = await resources_facade.AddPendingResources( + tag.application(application), + entity_url, + [client.CharmResource(**resource) for resource in resources]) + resource_map = {resource['name']: pid + for resource, pid + in zip(resources, response.pending_ids)} + return resource_map + async def _deploy(self, charm_url, application, series, config, constraints, endpoint_bindings, resources, storage, channel=None, num_units=None, placement=None): @@ -1076,8 +1185,8 @@ class Model(object): config = yaml.dump({application: config}, default_flow_style=False) - app_facade = client.ApplicationFacade() - app_facade.connect(self.connection) + app_facade = client.ApplicationFacade.from_connection( + self.connection) app = client.ApplicationDeploy( charm_url=charm_url, @@ -1090,7 +1199,7 @@ class Model(object): num_units=num_units, resources=resources, storage=storage, - placement=placement, + placement=placement ) result = await app_facade.Deploy([app]) @@ -1099,9 +1208,9 @@ class Model(object): raise JujuError('\n'.join(errors)) return await self._wait_for_new('application', application) - def destroy(self): + async def destroy(self): """Terminate all machines and resources for this model. - + Is already implemented in controller.py. """ raise NotImplementedError() @@ -1109,8 +1218,7 @@ class Model(object): """Destroy units by name. """ - app_facade = client.ApplicationFacade() - app_facade.connect(self.connection) + app_facade = client.ApplicationFacade.from_connection(self.connection) log.debug( 'Destroying unit%s %s', @@ -1148,11 +1256,20 @@ class Model(object): """ raise NotImplementedError() - def get_config(self): + async def get_config(self): """Return the configuration settings for this model. + :returns: A ``dict`` mapping keys to `ConfigValue` instances, + which have `source` and `value` attributes. """ - raise NotImplementedError() + config_facade = client.ModelConfigFacade.from_connection( + self.connection + ) + result = await config_facade.ModelGet() + config = result.config + for key, value in config.items(): + config[key] = ConfigValue.from_json(value) + return config def get_constraints(self): """Return the machine constraints for this model. @@ -1160,14 +1277,21 @@ class Model(object): """ raise NotImplementedError() - def grant(self, username, acl='read'): + async def grant(self, username, acl='read'): """Grant a user access to this model. :param str username: Username :param str acl: Access control ('read' or 'write') """ - raise NotImplementedError() + controller_conn = await self.connection.controller() + model_facade = client.ModelManagerFacade.from_connection( + controller_conn) + user = tag.user(username) + model = tag.model(self.info.uuid) + changes = client.ModifyModelAccess(acl, 'grant', model, user) + await self.revoke(username) + return await model_facade.ModifyModelAccess([changes]) def import_ssh_key(self, identity): """Add a public SSH key from a trusted indentity source to this model. @@ -1178,14 +1302,11 @@ class Model(object): raise NotImplementedError() import_ssh_keys = import_ssh_key - def get_machines(self, machine, utc=False): + async def get_machines(self): """Return list of machines in this model. - :param str machine: Machine id, e.g. '0' - :param bool utc: Display time as UTC in RFC3339 format - """ - raise NotImplementedError() + return list(self.state.machines.keys()) def get_shares(self): """Return list of all users with access to this model. @@ -1199,11 +1320,16 @@ class Model(object): """ raise NotImplementedError() - def get_ssh_key(self): + async def get_ssh_key(self, raw_ssh=False): """Return known SSH keys for this model. + :param bool raw_ssh: if True, returns the raw ssh key, + else it's fingerprint """ - raise NotImplementedError() + key_facade = client.KeyManagerFacade.from_connection(self.connection) + entity = {'tag': tag.model(self.info.uuid)} + entities = client.Entities([entity]) + return await key_facade.ListKeys(entities, raw_ssh) get_ssh_keys = get_ssh_key def get_storage(self, filesystem=False, volume=False): @@ -1266,13 +1392,18 @@ class Model(object): raise NotImplementedError() remove_machines = remove_machine - def remove_ssh_key(self, *keys): + async def remove_ssh_key(self, user, key): """Remove a public SSH key(s) from this model. - :param str \*keys: Keys to remove + :param str key: Full ssh key + :param str user: Juju user to which the key is registered """ - raise NotImplementedError() + key_facade = client.KeyManagerFacade.from_connection(self.connection) + key = base64.b64decode(bytes(key.strip().split()[1].encode('ascii'))) + key = hashlib.md5(key).hexdigest() + key = ':'.join(a+b for a, b in zip(key[::2], key[1::2])) + await key_facade.DeleteKeys([key], user) remove_ssh_keys = remove_ssh_key def restore_backup( @@ -1296,14 +1427,19 @@ class Model(object): """ raise NotImplementedError() - def revoke(self, username, acl='read'): + async def revoke(self, username): """Revoke a user's access to this model. :param str username: Username to revoke - :param str acl: Access control ('read' or 'write') """ - raise NotImplementedError() + controller_conn = await self.connection.controller() + model_facade = client.ModelManagerFacade.from_connection( + controller_conn) + user = tag.user(username) + model = tag.model(self.info.uuid) + changes = client.ModifyModelAccess('read', 'revoke', model, user) + return await model_facade.ModifyModelAccess([changes]) def run(self, command, timeout=None): """Run command on all machines in this model. @@ -1314,13 +1450,19 @@ class Model(object): """ raise NotImplementedError() - def set_config(self, **config): + async def set_config(self, config): """Set configuration keys on this model. - :param \*\*config: Config key/values - + :param dict config: Mapping of config keys to either string values or + `ConfigValue` instances, as returned by `get_config`. """ - raise NotImplementedError() + config_facade = client.ModelConfigFacade.from_connection( + self.connection + ) + for key, value in config.items(): + if isinstance(value, ConfigValue): + config[key] = value.value + await config_facade.ModelSet(config) def set_constraints(self, constraints): """Set machine constraints on this model. @@ -1364,8 +1506,7 @@ class Model(object): :param bool utc: Display time as UTC in RFC3339 format """ - client_facade = client.ClientFacade() - client_facade.connect(self.connection) + client_facade = client.ClientFacade.from_connection(self.connection) return await client_facade.FullStatus(filters) def sync_tools( @@ -1445,8 +1586,8 @@ class Model(object): log.debug("Retrieving metrics for %s", ', '.join(tags) if tags else "all units") - metrics_facade = client.MetricsDebugFacade() - metrics_facade.connect(self.connection) + metrics_facade = client.MetricsDebugFacade.from_connection( + self.connection) entities = [client.Entity(tag) for tag in tags] metrics_result = await metrics_facade.GetMetrics(entities) @@ -1496,12 +1637,12 @@ class BundleHandler(object): for unit_name, unit in model.units.items(): app_units = self._units_by_app.setdefault(unit.application, []) app_units.append(unit_name) - self.client_facade = client.ClientFacade() - self.client_facade.connect(model.connection) - self.app_facade = client.ApplicationFacade() - self.app_facade.connect(model.connection) - self.ann_facade = client.AnnotationsFacade() - self.ann_facade.connect(model.connection) + self.client_facade = client.ClientFacade.from_connection( + model.connection) + self.app_facade = client.ApplicationFacade.from_connection( + model.connection) + self.ann_facade = client.AnnotationsFacade.from_connection( + model.connection) async def _handle_local_charms(self, bundle): """Search for references to local charms (i.e. filesystem paths) @@ -1565,9 +1706,6 @@ class BundleHandler(object): self.plan = await self.client_facade.GetBundleChanges( yaml.dump(self.bundle)) - if self.plan.errors: - raise JujuError('\n'.join(self.plan.errors)) - async def execute_plan(self): for step in self.plan.changes: method = getattr(self, step.method) @@ -1647,7 +1785,7 @@ class BundleHandler(object): results = await self.client_facade.AddMachines([params]) error = results.machines[0].error if error: - raise ValueError("Error adding machine: %s", error.message) + raise ValueError("Error adding machine: %s" % error.message) machine = results.machines[0].machine log.debug('Added new machine %s', machine) return machine @@ -1703,6 +1841,11 @@ class BundleHandler(object): """ # resolve indirect references charm = self.resolve(charm) + # the bundle plan doesn't actually do anything with resources, even + # though it ostensibly gives us something (None) for that param + if not charm.startswith('local:'): + resources = await self.model._add_store_resources(application, + charm) await self.model._deploy( charm_url=charm, application=application, @@ -1807,6 +1950,12 @@ class CharmStore(object): class CharmArchiveGenerator(object): + """ + Create a Zip archive of a local charm directory for upload to a controller. + + This is used automatically by + `Model.add_local_charm_dir <#juju.model.Model.add_local_charm_dir>`_. + """ def __init__(self, path): self.path = os.path.abspath(os.path.expanduser(path))