Commit 17583c8b authored by Mark Beierl's avatar Mark Beierl
Browse files

Merge branch 'master' into 'master'

Update squid charm

See merge request !105
parents e2197769 9c012576
Pipeline #143 passed with stage
in 1 minute and 34 seconds
Welcome to The Operator Framework's documentation!
==================================================
.. toctree::
:maxdepth: 2
:caption: Contents:
ops package
===========
.. automodule:: ops
Submodules
----------
ops.charm module
----------------
.. automodule:: ops.charm
ops.framework module
--------------------
.. automodule:: ops.framework
ops.jujuversion module
----------------------
.. automodule:: ops.jujuversion
ops.log module
--------------
.. automodule:: ops.log
ops.main module
---------------
.. automodule:: ops.main
ops.model module
----------------
.. automodule:: ops.model
ops.testing module
------------------
.. automodule:: ops.testing
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
# Copyright 2019-2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import collections.abc
import inspect
import keyword
import marshal
import os
import pdb
import pickle
import re
import sqlite3
import sys
import types
import weakref
from datetime import timedelta
from ops import charm
class Handle:
"""Handle defines a name for an object in the form of a hierarchical path.
The provided parent is the object (or that object's handle) that this handle
sits under, or None if the object identified by this handle stands by itself
as the root of its own hierarchy.
The handle kind is a string that defines a namespace so objects with the
same parent and kind will have unique keys.
The handle key is a string uniquely identifying the object. No other objects
under the same parent and kind may have the same key.
"""
def __init__(self, parent, kind, key):
if parent and not isinstance(parent, Handle):
parent = parent.handle
self._parent = parent
self._kind = kind
self._key = key
if parent:
if key:
self._path = "{}/{}[{}]".format(parent, kind, key)
else:
self._path = "{}/{}".format(parent, kind)
else:
if key:
self._path = "{}[{}]".format(kind, key)
else:
self._path = "{}".format(kind)
def nest(self, kind, key):
return Handle(self, kind, key)
def __hash__(self):
return hash((self.parent, self.kind, self.key))
def __eq__(self, other):
return (self.parent, self.kind, self.key) == (other.parent, other.kind, other.key)
def __str__(self):
return self.path
@property
def parent(self):
return self._parent
@property
def kind(self):
return self._kind
@property
def key(self):
return self._key
@property
def path(self):
return self._path
@classmethod
def from_path(cls, path):
handle = None
for pair in path.split("/"):
pair = pair.split("[")
good = False
if len(pair) == 1:
kind, key = pair[0], None
good = True
elif len(pair) == 2:
kind, key = pair
if key and key[-1] == ']':
key = key[:-1]
good = True
if not good:
raise RuntimeError("attempted to restore invalid handle path {}".format(path))
handle = Handle(handle, kind, key)
return handle
class EventBase:
def __init__(self, handle):
self.handle = handle
self.deferred = False
def defer(self):
self.deferred = True
def snapshot(self):
"""Return the snapshot data that should be persisted.
Subclasses must override to save any custom state.
"""
return None
def restore(self, snapshot):
"""Restore the value state from the given snapshot.
Subclasses must override to restore their custom state.
"""
self.deferred = False
class EventSource:
"""EventSource wraps an event type with a descriptor to facilitate observing and emitting.
It is generally used as:
class SomethingHappened(EventBase):
pass
class SomeObject(Object):
something_happened = EventSource(SomethingHappened)
With that, instances of that type will offer the someobj.something_happened
attribute which is a BoundEvent and may be used to emit and observe the event.
"""
def __init__(self, event_type):
if not isinstance(event_type, type) or not issubclass(event_type, EventBase):
raise RuntimeError(
'Event requires a subclass of EventBase as an argument, got {}'.format(event_type))
self.event_type = event_type
self.event_kind = None
self.emitter_type = None
def _set_name(self, emitter_type, event_kind):
if self.event_kind is not None:
raise RuntimeError(
'EventSource({}) reused as {}.{} and {}.{}'.format(
self.event_type.__name__,
self.emitter_type.__name__,
self.event_kind,
emitter_type.__name__,
event_kind,
))
self.event_kind = event_kind
self.emitter_type = emitter_type
def __get__(self, emitter, emitter_type=None):
if emitter is None:
return self
# Framework might not be available if accessed as CharmClass.on.event
# rather than charm_instance.on.event, but in that case it couldn't be
# emitted anyway, so there's no point to registering it.
framework = getattr(emitter, 'framework', None)
if framework is not None:
framework.register_type(self.event_type, emitter, self.event_kind)
return BoundEvent(emitter, self.event_type, self.event_kind)
class BoundEvent:
def __repr__(self):
return '<BoundEvent {} bound to {}.{} at {}>'.format(
self.event_type.__name__,
type(self.emitter).__name__,
self.event_kind,
hex(id(self)),
)
def __init__(self, emitter, event_type, event_kind):
self.emitter = emitter
self.event_type = event_type
self.event_kind = event_kind
def emit(self, *args, **kwargs):
"""Emit event to all registered observers.
The current storage state is committed before and after each observer is notified.
"""
framework = self.emitter.framework
key = framework._next_event_key()
event = self.event_type(Handle(self.emitter, self.event_kind, key), *args, **kwargs)
framework._emit(event)
class HandleKind:
"""Helper descriptor to define the Object.handle_kind field.
The handle_kind for an object defaults to its type name, but it may
be explicitly overridden if desired.
"""
def __get__(self, obj, obj_type):
kind = obj_type.__dict__.get("handle_kind")
if kind:
return kind
return obj_type.__name__
class _Metaclass(type):
"""Helper class to ensure proper instantiation of Object-derived classes.
This class currently has a single purpose: events derived from EventSource
that are class attributes of Object-derived classes need to be told what
their name is in that class. For example, in
class SomeObject(Object):
something_happened = EventSource(SomethingHappened)
the instance of EventSource needs to know it's called 'something_happened'.
Starting from python 3.6 we could use __set_name__ on EventSource for this,
but until then this (meta)class does the equivalent work.
TODO: when we drop support for 3.5 drop this class, and rename _set_name in
EventSource to __set_name__; everything should continue to work.
"""
def __new__(typ, *a, **kw):
k = super().__new__(typ, *a, **kw)
# k is now the Object-derived class; loop over its class attributes
for n, v in vars(k).items():
# we could do duck typing here if we want to support
# non-EventSource-derived shenanigans. We don't.
if isinstance(v, EventSource):
# this is what 3.6+ does automatically for us:
v._set_name(k, n)
return k
class Object(metaclass=_Metaclass):
handle_kind = HandleKind()
def __init__(self, parent, key):
kind = self.handle_kind
if isinstance(parent, Framework):
self.framework = parent
# Avoid Framework instances having a circular reference to themselves.
if self.framework is self:
self.framework = weakref.proxy(self.framework)
self.handle = Handle(None, kind, key)
else:
self.framework = parent.framework
self.handle = Handle(parent, kind, key)
self.framework._track(self)
# TODO Detect conflicting handles here.
@property
def model(self):
return self.framework.model
class ObjectEvents(Object):
"""Convenience type to allow defining .on attributes at class level."""
handle_kind = "on"
def __init__(self, parent=None, key=None):
if parent is not None:
super().__init__(parent, key)
else:
self._cache = weakref.WeakKeyDictionary()
def __get__(self, emitter, emitter_type):
if emitter is None:
return self
instance = self._cache.get(emitter)
if instance is None:
# Same type, different instance, more data. Doing this unusual construct
# means people can subclass just this one class to have their own 'on'.
instance = self._cache[emitter] = type(self)(emitter)
return instance
@classmethod
def define_event(cls, event_kind, event_type):
"""Define an event on this type at runtime.
cls: a type to define an event on.
event_kind: an attribute name that will be used to access the
event. Must be a valid python identifier, not be a keyword
or an existing attribute.
event_type: a type of the event to define.
"""
prefix = 'unable to define an event with event_kind that '
if not event_kind.isidentifier():
raise RuntimeError(prefix + 'is not a valid python identifier: ' + event_kind)
elif keyword.iskeyword(event_kind):
raise RuntimeError(prefix + 'is a python keyword: ' + event_kind)
try:
getattr(cls, event_kind)
raise RuntimeError(
prefix + 'overlaps with an existing type {} attribute: {}'.format(cls, event_kind))
except AttributeError:
pass
event_descriptor = EventSource(event_type)
event_descriptor._set_name(cls, event_kind)
setattr(cls, event_kind, event_descriptor)
def events(self):
"""Return a mapping of event_kinds to bound_events for all available events.
"""
events_map = {}
# We have to iterate over the class rather than instance to allow for properties which
# might call this method (e.g., event views), leading to infinite recursion.
for attr_name, attr_value in inspect.getmembers(type(self)):
if isinstance(attr_value, EventSource):
# We actually care about the bound_event, however, since it
# provides the most info for users of this method.
event_kind = attr_name
bound_event = getattr(self, event_kind)
events_map[event_kind] = bound_event
return events_map
def __getitem__(self, key):
return PrefixedEvents(self, key)
class PrefixedEvents:
def __init__(self, emitter, key):
self._emitter = emitter
self._prefix = key.replace("-", "_") + '_'
def __getattr__(self, name):
return getattr(self._emitter, self._prefix + name)
class PreCommitEvent(EventBase):
pass
class CommitEvent(EventBase):
pass
class FrameworkEvents(ObjectEvents):
pre_commit = EventSource(PreCommitEvent)
commit = EventSource(CommitEvent)
class NoSnapshotError(Exception):
def __init__(self, handle_path):
self.handle_path = handle_path
def __str__(self):
return 'no snapshot data found for {} object'.format(self.handle_path)
class NoTypeError(Exception):
def __init__(self, handle_path):
self.handle_path = handle_path
def __str__(self):
return "cannot restore {} since no class was registered for it".format(self.handle_path)
class SQLiteStorage:
DB_LOCK_TIMEOUT = timedelta(hours=1)
def __init__(self, filename):
# The isolation_level argument is set to None such that the implicit
# transaction management behavior of the sqlite3 module is disabled.
self._db = sqlite3.connect(str(filename),
isolation_level=None,
timeout=self.DB_LOCK_TIMEOUT.total_seconds())
self._setup()
def _setup(self):
# Make sure that the database is locked until the connection is closed,
# not until the transaction ends.
self._db.execute("PRAGMA locking_mode=EXCLUSIVE")
c = self._db.execute("BEGIN")
c.execute("SELECT count(name) FROM sqlite_master WHERE type='table' AND name='snapshot'")
if c.fetchone()[0] == 0:
# Keep in mind what might happen if the process dies somewhere below.
# The system must not be rendered permanently broken by that.
self._db.execute("CREATE TABLE snapshot (handle TEXT PRIMARY KEY, data BLOB)")
self._db.execute('''
CREATE TABLE notice (
sequence INTEGER PRIMARY KEY AUTOINCREMENT,
event_path TEXT,
observer_path TEXT,
method_name TEXT)
''')
self._db.commit()
def close(self):
self._db.close()
def commit(self):
self._db.commit()
# There's commit but no rollback. For abort to be supported, we'll need logic that
# can rollback decisions made by third-party code in terms of the internal state
# of objects that have been snapshotted, and hooks to let them know about it and
# take the needed actions to undo their logic until the last snapshot.
# This is doable but will increase significantly the chances for mistakes.
def save_snapshot(self, handle_path, snapshot_data):
self._db.execute("REPLACE INTO snapshot VALUES (?, ?)", (handle_path, snapshot_data))
def load_snapshot(self, handle_path):
c = self._db.cursor()
c.execute("SELECT data FROM snapshot WHERE handle=?", (handle_path,))
row = c.fetchone()
if row:
return row[0]
return None
def drop_snapshot(self, handle_path):
self._db.execute("DELETE FROM snapshot WHERE handle=?", (handle_path,))
def save_notice(self, event_path, observer_path, method_name):
self._db.execute('INSERT INTO notice VALUES (NULL, ?, ?, ?)',
(event_path, observer_path, method_name))
def drop_notice(self, event_path, observer_path, method_name):
self._db.execute('''
DELETE FROM notice
WHERE event_path=?
AND observer_path=?
AND method_name=?
''', (event_path, observer_path, method_name))
def notices(self, event_path):
if event_path:
c = self._db.execute('''
SELECT event_path, observer_path, method_name
FROM notice
WHERE event_path=?
ORDER BY sequence
''', (event_path,))
else:
c = self._db.execute('''
SELECT event_path, observer_path, method_name
FROM notice
ORDER BY sequence
''')
while True:
rows = c.fetchmany()
if not rows:
break
for row in rows:
yield tuple(row)
# the message to show to the user when a pdb breakpoint goes active
_BREAKPOINT_WELCOME_MESSAGE = """
Starting pdb to debug charm operator.
Run `h` for help, `c` to continue, or `exit`/CTRL-d to abort.
Future breakpoints may interrupt execution again.
More details at https://discourse.jujucharms.com/t/debugging-charm-hooks
"""
class Framework(Object):
on = FrameworkEvents()
# Override properties from Object so that we can set them in __init__.
model = None
meta = None
charm_dir = None
def __init__(self, data_path, charm_dir, meta, model):
super().__init__(self, None)
self._data_path = data_path
self.charm_dir = charm_dir
self.meta = meta
self.model = model
self._observers = [] # [(observer_path, method_name, parent_path, event_key)]
self._observer = weakref.WeakValueDictionary() # {observer_path: observer}
self._objects = weakref.WeakValueDictionary()
self._type_registry = {} # {(parent_path, kind): cls}
self._type_known = set() # {cls}
self._storage = SQLiteStorage(data_path)
# We can't use the higher-level StoredState because it relies on events.
self.register_type(StoredStateData, None, StoredStateData.handle_kind)
stored_handle = Handle(None, StoredStateData.handle_kind, '_stored')
try:
self._stored = self.load_snapshot(stored_handle)
except NoSnapshotError:
self._stored = StoredStateData(self, '_stored')
self._stored['event_count'] = 0
# Hook into builtin breakpoint, so if Python >= 3.7, devs will be able to just do
# breakpoint(); if Python < 3.7, this doesn't affect anything
sys.breakpointhook = self.breakpoint
# Flag to indicate that we already presented the welcome message in a debugger breakpoint
self._breakpoint_welcomed = False
# Parse once the env var, which may be used multiple times later
debug_at = os.environ.get('JUJU_DEBUG_AT')
self._juju_debug_at = debug_at.split(',') if debug_at else ()
def close(self):
self._storage.close()
def _track(self, obj):
"""Track object and ensure it is the only object created using its handle path."""
if obj is self:
# Framework objects don't track themselves
return
if obj.handle.path in self.framework._objects:
raise RuntimeError(
'two objects claiming to be {} have been created'.format(obj.handle.path))
self._objects[obj.handle.path] = obj
def _forget(self, obj):
"""Stop tracking the given object. See also _track."""
self._objects.pop(obj.handle.path, None)
def commit(self):
# Give a chance for objects to persist data they want to before a commit is made.
self.on.pre_commit.emit()
# Make sure snapshots are saved by instances of StoredStateData. Any possible state
# modifications in on_commit handlers of instances of other classes will not be persisted.
self.on.commit.emit()
# Save our event count after all events have been emitted.
self.save_snapshot(self._stored)
self._storage.commit()
def register_type(self, cls, parent, kind=None):
if parent and not isinstance(parent, Handle):
parent = parent.handle
if parent:
parent_path = parent.path
else:
parent_path = None
if not kind:
kind = cls.handle_kind
self._type_registry[(parent_path, kind)] = cls
self._type_known.add(cls)
def save_snapshot(self, value):
"""Save a persistent snapshot of the provided value.
The provided value must implement the following interface:
value.handle = Handle(...)
value.snapshot() => {...} # Simple builtin types only.
value.restore(snapshot) # Restore custom state from prior snapshot.
"""
if type(value) not in self._type_known:
raise RuntimeError(
'cannot save {} values before registering that type'.format(type(value).__name__))
data = value.snapshot()
# Use marshal as a validator, enforcing the use of simple types, as we later the
# information is really pickled, which is too error prone for future evolution of the
# stored data (e.g. if the developer stores a custom object and later changes its
# class name; when unpickling the original class will not be there and event
# data loading will fail).
try:
marshal.dumps(data)
except ValueError:
msg = "unable to save the data for {}, it must contain only simple types: {!r}"
raise ValueError(msg.format(value.__class__.__name__, data))
# Use pickle for serialization, so the value remains portable.
raw_data = pickle.dumps(data)
self._storage.save_snapshot(value.handle.path, raw_data)
def load_snapshot(self, handle):
parent_path = None
if handle.parent:
parent_path = handle.parent.path
cls = self._type_registry.get((parent_path, handle.kind))
if not cls:
raise NoTypeError(handle.path)
raw_data = self._storage.load_snapshot(handle.path)
if not raw_data:
raise NoSnapshotError(handle.path)
data = pickle.loads(raw_data)
obj = cls.__new__(cls)
obj.framework = self
obj.handle = handle
obj.restore(data)
self._track(obj)
return obj
def drop_snapshot(self, handle):
self._storage.drop_snapshot(handle.path)
def observe(self, bound_event, observer):
"""Register observer to be called when bound_event is emitted.
The bound_event is generally provided as an attribute of the object that emits
the event, and is created in this style:
class SomeObject:
something_happened = Event(SomethingHappened)
That event may be observed as:
framework.observe(someobj.something_happened, self.on_something_happened)
If the method to be called follows the name convention "on_<event name>", it
may be omitted from the observe call. That means the above is equivalent to:
framework.observe(someobj.something_happened, self)
"""
if not isinstance(bound_event, BoundEvent):
raise RuntimeError(
'Framework.observe requires a BoundEvent as second parameter, got {}'.format(
bound_event))
event_type = bound_event.event_type
event_kind = bound_event.event_kind
emitter = bound_event.emitter
self.register_type(event_type, emitter, event_kind)
if hasattr(emitter, "handle"):
emitter_path = emitter.handle.path
else:
raise RuntimeError(
'event emitter {} must have a "handle" attribute'.format(type(emitter).__name__))
method_name = None
if isinstance(observer, types.MethodType):
method_name = observer.__name__
observer = observer.__self__
else:
method_name = "on_" + event_kind
if not hasattr(observer, method_name):
raise RuntimeError(
'Observer method not provided explicitly'
' and {} type has no "{}" method'.format(type(observer).__name__,
method_name))
# Validate that the method has an acceptable call signature.
sig = inspect.signature(getattr(observer, method_name))
# Self isn't included in the params list, so the first arg will be the event.
extra_params = list(sig.parameters.values())[1:]
if not sig.parameters:
raise TypeError(
'{}.{} must accept event parameter'.format(type(observer).__name__, method_name))
elif any(param.default is inspect.Parameter.empty for param in extra_params):
# Allow for additional optional params, since there's no reason to exclude them, but
# required params will break.
raise TypeError(
'{}.{} has extra required parameter'.format(type(observer).__name__, method_name))
# TODO Prevent the exact same parameters from being registered more than once.
self._observer[observer.handle.path] = observer
self._observers.append((observer.handle.path, method_name, emitter_path, event_kind))
def _next_event_key(self):
"""Return the next event key that should be used, incrementing the internal counter."""
# Increment the count first; this means the keys will start at 1, and 0
# means no events have been emitted.
self._stored['event_count'] += 1
return str(self._stored['event_count'])
def _emit(self, event):
"""See BoundEvent.emit for the public way to call this."""
# Save the event for all known observers before the first notification
# takes place, so that either everyone interested sees it, or nobody does.
self.save_snapshot(event)
event_path = event.handle.path
event_kind = event.handle.kind
parent_path = event.handle.parent.path
# TODO Track observers by (parent_path, event_kind) rather than as a list of
# all observers. Avoiding linear search through all observers for every event
for observer_path, method_name, _parent_path, _event_kind in self._observers:
if _parent_path != parent_path:
continue
if _event_kind and _event_kind != event_kind:
continue
# Again, only commit this after all notices are saved.
self._storage.save_notice(event_path, observer_path, method_name)
self._reemit(event_path)
def reemit(self):
"""Reemit previously deferred events to the observers that deferred them.
Only the specific observers that have previously deferred the event will be
notified again. Observers that asked to be notified about events after it's
been first emitted won't be notified, as that would mean potentially observing
events out of order.
"""
self._reemit()
def _reemit(self, single_event_path=None):
last_event_path = None
deferred = True
for event_path, observer_path, method_name in self._storage.notices(single_event_path):
event_handle = Handle.from_path(event_path)
if last_event_path != event_path:
if not deferred:
self._storage.drop_snapshot(last_event_path)
last_event_path = event_path
deferred = False
try:
event = self.load_snapshot(event_handle)
except NoTypeError:
self._storage.drop_notice(event_path, observer_path, method_name)
continue
event.deferred = False
observer = self._observer.get(observer_path)
if observer:
custom_handler = getattr(observer, method_name, None)
if custom_handler:
event_is_from_juju = isinstance(event, charm.HookEvent)
event_is_action = isinstance(event, charm.ActionEvent)
if (event_is_from_juju or event_is_action) and 'hook' in self._juju_debug_at:
# Present the welcome message and run under PDB.
self._show_debug_code_message()
pdb.runcall(custom_handler, event)
else:
# Regular call to the registered method.
custom_handler(event)
if event.deferred:
deferred = True
else:
self._storage.drop_notice(event_path, observer_path, method_name)
# We intentionally consider this event to be dead and reload it from
# scratch in the next path.
self.framework._forget(event)
if not deferred:
self._storage.drop_snapshot(last_event_path)
def _show_debug_code_message(self):
"""Present the welcome message (only once!) when using debugger functionality."""
if not self._breakpoint_welcomed:
self._breakpoint_welcomed = True
print(_BREAKPOINT_WELCOME_MESSAGE, file=sys.stderr, end='')
def breakpoint(self, name=None):
"""Add breakpoint, optionally named, at the place where this method is called.
For the breakpoint to be activated the JUJU_DEBUG_AT environment variable
must be set to "all" or to the specific name parameter provided, if any. In every
other situation calling this method does nothing.
The framework also provides a standard breakpoint named "hook", that will
stop execution when a hook event is about to be handled.
For those reasons, the "all" and "hook" breakpoint names are reserved.
"""
# If given, validate the name comply with all the rules
if name is not None:
if not isinstance(name, str):
raise TypeError('breakpoint names must be strings')
if name in ('hook', 'all'):
raise ValueError('breakpoint names "all" and "hook" are reserved')
if not re.match(r'^[a-z0-9]([a-z0-9\-]*[a-z0-9])?$', name):
raise ValueError('breakpoint names must look like "foo" or "foo-bar"')
indicated_breakpoints = self._juju_debug_at
if 'all' in indicated_breakpoints or name in indicated_breakpoints:
self._show_debug_code_message()
# If we call set_trace() directly it will open the debugger *here*, so indicating
# it to use our caller's frame
code_frame = inspect.currentframe().f_back
pdb.Pdb().set_trace(code_frame)
class StoredStateData(Object):
def __init__(self, parent, attr_name):
super().__init__(parent, attr_name)
self._cache = {}
self.dirty = False
def __getitem__(self, key):
return self._cache.get(key)
def __setitem__(self, key, value):
self._cache[key] = value
self.dirty = True
def __contains__(self, key):
return key in self._cache
def snapshot(self):
return self._cache
def restore(self, snapshot):
self._cache = snapshot
self.dirty = False
def on_commit(self, event):
if self.dirty:
self.framework.save_snapshot(self)
self.dirty = False
class BoundStoredState:
def __init__(self, parent, attr_name):
parent.framework.register_type(StoredStateData, parent)
handle = Handle(parent, StoredStateData.handle_kind, attr_name)
try:
data = parent.framework.load_snapshot(handle)
except NoSnapshotError:
data = StoredStateData(parent, attr_name)
# __dict__ is used to avoid infinite recursion.
self.__dict__["_data"] = data
self.__dict__["_attr_name"] = attr_name
parent.framework.observe(parent.framework.on.commit, self._data)
def __getattr__(self, key):
# "on" is the only reserved key that can't be used in the data map.
if key == "on":
return self._data.on
if key not in self._data:
raise AttributeError("attribute '{}' is not stored".format(key))
return _wrap_stored(self._data, self._data[key])
def __setattr__(self, key, value):
if key == "on":
raise AttributeError("attribute 'on' is reserved and cannot be set")
value = _unwrap_stored(self._data, value)
if not isinstance(value, (type(None), int, float, str, bytes, list, dict, set)):
raise AttributeError(
'attribute {!r} cannot be a {}: must be int/float/dict/list/etc'.format(
key, type(value).__name__))
self._data[key] = _unwrap_stored(self._data, value)
def set_default(self, **kwargs):
""""Set the value of any given key if it has not already been set"""
for k, v in kwargs.items():
if k not in self._data:
self._data[k] = v
class StoredState:
"""A class used to store data the charm needs persisted across invocations.
Example::
class MyClass(Object):
_stored = StoredState()
Instances of `MyClass` can transparently save state between invocations by
setting attributes on `_stored`. Initial state should be set with
`set_default` on the bound object, that is::
class MyClass(Object):
_stored = StoredState()
def __init__(self, parent, key):
super().__init__(parent, key)
self._stored.set_default(seen=set())
self.framework.observe(self.on.seen, self._on_seen)
def _on_seen(self, event):
self._stored.seen.add(event.uuid)
"""
def __init__(self):
self.parent_type = None
self.attr_name = None
def __get__(self, parent, parent_type=None):
if self.parent_type is not None and self.parent_type not in parent_type.mro():
# the StoredState instance is being shared between two unrelated classes
# -> unclear what is exepcted of us -> bail out
raise RuntimeError(
'StoredState shared by {} and {}'.format(
self.parent_type.__name__, parent_type.__name__))
if parent is None:
# accessing via the class directly (e.g. MyClass.stored)
return self
bound = None
if self.attr_name is not None:
bound = parent.__dict__.get(self.attr_name)
if bound is not None:
# we already have the thing from a previous pass, huzzah
return bound
# need to find ourselves amongst the parent's bases
for cls in parent_type.mro():
for attr_name, attr_value in cls.__dict__.items():
if attr_value is not self:
continue
# we've found ourselves! is it the first time?
if bound is not None:
# the StoredState instance is being stored in two different
# attributes -> unclear what is expected of us -> bail out
raise RuntimeError("StoredState shared by {0}.{1} and {0}.{2}".format(
cls.__name__, self.attr_name, attr_name))
# we've found ourselves for the first time; save where, and bind the object
self.attr_name = attr_name
self.parent_type = cls
bound = BoundStoredState(parent, attr_name)
if bound is not None:
# cache the bound object to avoid the expensive lookup the next time
# (don't use setattr, to keep things symmetric with the fast-path lookup above)
parent.__dict__[self.attr_name] = bound
return bound
raise AttributeError(
'cannot find {} attribute in type {}'.format(
self.__class__.__name__, parent_type.__name__))
def _wrap_stored(parent_data, value):
t = type(value)
if t is dict:
return StoredDict(parent_data, value)
if t is list:
return StoredList(parent_data, value)
if t is set:
return StoredSet(parent_data, value)
return value
def _unwrap_stored(parent_data, value):
t = type(value)
if t is StoredDict or t is StoredList or t is StoredSet:
return value._under
return value
class StoredDict(collections.abc.MutableMapping):
def __init__(self, stored_data, under):
self._stored_data = stored_data
self._under = under
def __getitem__(self, key):
return _wrap_stored(self._stored_data, self._under[key])
def __setitem__(self, key, value):
self._under[key] = _unwrap_stored(self._stored_data, value)
self._stored_data.dirty = True
def __delitem__(self, key):
del self._under[key]
self._stored_data.dirty = True
def __iter__(self):
return self._under.__iter__()
def __len__(self):
return len(self._under)
def __eq__(self, other):
if isinstance(other, StoredDict):
return self._under == other._under
elif isinstance(other, collections.abc.Mapping):
return self._under == other
else:
return NotImplemented
class StoredList(collections.abc.MutableSequence):
def __init__(self, stored_data, under):
self._stored_data = stored_data
self._under = under
def __getitem__(self, index):
return _wrap_stored(self._stored_data, self._under[index])
def __setitem__(self, index, value):
self._under[index] = _unwrap_stored(self._stored_data, value)
self._stored_data.dirty = True
def __delitem__(self, index):
del self._under[index]
self._stored_data.dirty = True
def __len__(self):
return len(self._under)
def insert(self, index, value):
self._under.insert(index, value)
self._stored_data.dirty = True
def append(self, value):
self._under.append(value)
self._stored_data.dirty = True
def __eq__(self, other):
if isinstance(other, StoredList):
return self._under == other._under
elif isinstance(other, collections.abc.Sequence):
return self._under == other
else:
return NotImplemented
def __lt__(self, other):
if isinstance(other, StoredList):
return self._under < other._under
elif isinstance(other, collections.abc.Sequence):
return self._under < other
else:
return NotImplemented
def __le__(self, other):
if isinstance(other, StoredList):
return self._under <= other._under
elif isinstance(other, collections.abc.Sequence):
return self._under <= other
else:
return NotImplemented
def __gt__(self, other):
if isinstance(other, StoredList):
return self._under > other._under
elif isinstance(other, collections.abc.Sequence):
return self._under > other
else:
return NotImplemented
def __ge__(self, other):
if isinstance(other, StoredList):
return self._under >= other._under
elif isinstance(other, collections.abc.Sequence):
return self._under >= other
else:
return NotImplemented
class StoredSet(collections.abc.MutableSet):
def __init__(self, stored_data, under):
self._stored_data = stored_data
self._under = under
def add(self, key):
self._under.add(key)
self._stored_data.dirty = True
def discard(self, key):
self._under.discard(key)
self._stored_data.dirty = True
def __contains__(self, key):
return key in self._under
def __iter__(self):
return self._under.__iter__()
def __len__(self):
return len(self._under)
@classmethod
def _from_iterable(cls, it):
"""Construct an instance of the class from any iterable input.
Per https://docs.python.org/3/library/collections.abc.html
if the Set mixin is being used in a class with a different constructor signature,
you will need to override _from_iterable() with a classmethod that can construct
new instances from an iterable argument.
"""
return set(it)
def __le__(self, other):
if isinstance(other, StoredSet):
return self._under <= other._under
elif isinstance(other, collections.abc.Set):
return self._under <= other
else:
return NotImplemented
def __ge__(self, other):
if isinstance(other, StoredSet):
return self._under >= other._under
elif isinstance(other, collections.abc.Set):
return self._under >= other
else:
return NotImplemented
def __eq__(self, other):
if isinstance(other, StoredSet):
return self._under == other._under
elif isinstance(other, collections.abc.Set):
return self._under == other
else:
return NotImplemented
# Copyright 2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from functools import total_ordering
@total_ordering
class JujuVersion:
PATTERN = r'''^
(?P<major>\d{1,9})\.(?P<minor>\d{1,9}) # <major> and <minor> numbers are always there
((?:\.|-(?P<tag>[a-z]+))(?P<patch>\d{1,9}))? # sometimes with .<patch> or -<tag><patch>
(\.(?P<build>\d{1,9}))?$ # and sometimes with a <build> number.
'''
def __init__(self, version):
m = re.match(self.PATTERN, version, re.VERBOSE)
if not m:
raise RuntimeError('"{}" is not a valid Juju version string'.format(version))
d = m.groupdict()
self.major = int(m.group('major'))
self.minor = int(m.group('minor'))
self.tag = d['tag'] or ''
self.patch = int(d['patch'] or 0)
self.build = int(d['build'] or 0)
def __repr__(self):
if self.tag:
s = '{}.{}-{}{}'.format(self.major, self.minor, self.tag, self.patch)
else:
s = '{}.{}.{}'.format(self.major, self.minor, self.patch)
if self.build > 0:
s += '.{}'.format(self.build)
return s
def __eq__(self, other):
if self is other:
return True
if isinstance(other, str):
other = type(self)(other)
elif not isinstance(other, JujuVersion):
raise RuntimeError('cannot compare Juju version "{}" with "{}"'.format(self, other))
return (
self.major == other.major
and self.minor == other.minor
and self.tag == other.tag
and self.build == other.build
and self.patch == other.patch)
def __lt__(self, other):
if self is other:
return False
if isinstance(other, str):
other = type(self)(other)
elif not isinstance(other, JujuVersion):
raise RuntimeError('cannot compare Juju version "{}" with "{}"'.format(self, other))
if self.major != other.major:
return self.major < other.major
elif self.minor != other.minor:
return self.minor < other.minor
elif self.tag != other.tag:
if not self.tag:
return False
elif not other.tag:
return True
return self.tag < other.tag
elif self.patch != other.patch:
return self.patch < other.patch
elif self.build != other.build:
return self.build < other.build
return False
#!/usr/bin/env python3
# Copyright 2019 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from pathlib import Path
import yaml
import ops.charm
import ops.framework
import ops.model
import logging
from ops.log import setup_root_logging
CHARM_STATE_FILE = '.unit-state.db'
logger = logging.getLogger()
def _get_charm_dir():
charm_dir = os.environ.get("JUJU_CHARM_DIR")
if charm_dir is None:
# Assume $JUJU_CHARM_DIR/lib/op/main.py structure.
charm_dir = Path('{}/../../..'.format(__file__)).resolve()
else:
charm_dir = Path(charm_dir).resolve()
return charm_dir
def _load_metadata(charm_dir):
metadata = yaml.safe_load((charm_dir / 'metadata.yaml').read_text())
actions_meta = charm_dir / 'actions.yaml'
if actions_meta.exists():
actions_metadata = yaml.safe_load(actions_meta.read_text())
else:
actions_metadata = {}
return metadata, actions_metadata
def _create_event_link(charm, bound_event):
"""Create a symlink for a particular event.
charm -- A charm object.
bound_event -- An event for which to create a symlink.
"""
if issubclass(bound_event.event_type, ops.charm.HookEvent):
event_dir = charm.framework.charm_dir / 'hooks'
event_path = event_dir / bound_event.event_kind.replace('_', '-')
elif issubclass(bound_event.event_type, ops.charm.ActionEvent):
if not bound_event.event_kind.endswith("_action"):
raise RuntimeError(
'action event name {} needs _action suffix'.format(bound_event.event_kind))
event_dir = charm.framework.charm_dir / 'actions'
# The event_kind is suffixed with "_action" while the executable is not.
event_path = event_dir / bound_event.event_kind[:-len('_action')].replace('_', '-')
else:
raise RuntimeError(
'cannot create a symlink: unsupported event type {}'.format(bound_event.event_type))
event_dir.mkdir(exist_ok=True)
if not event_path.exists():
# CPython has different implementations for populating sys.argv[0] for Linux and Windows.
# For Windows it is always an absolute path (any symlinks are resolved)
# while for Linux it can be a relative path.
target_path = os.path.relpath(os.path.realpath(sys.argv[0]), str(event_dir))
# Ignore the non-symlink files or directories
# assuming the charm author knows what they are doing.
logger.debug(
'Creating a new relative symlink at %s pointing to %s',
event_path, target_path)
event_path.symlink_to(target_path)
def _setup_event_links(charm_dir, charm):
"""Set up links for supported events that originate from Juju.
Whether a charm can handle an event or not can be determined by
introspecting which events are defined on it.
Hooks or actions are created as symlinks to the charm code file
which is determined by inspecting symlinks provided by the charm
author at hooks/install or hooks/start.
charm_dir -- A root directory of the charm.
charm -- An instance of the Charm class.
"""
for bound_event in charm.on.events().values():
# Only events that originate from Juju need symlinks.
if issubclass(bound_event.event_type, (ops.charm.HookEvent, ops.charm.ActionEvent)):
_create_event_link(charm, bound_event)
def _emit_charm_event(charm, event_name):
"""Emits a charm event based on a Juju event name.
charm -- A charm instance to emit an event from.
event_name -- A Juju event name to emit on a charm.
"""
event_to_emit = None
try:
event_to_emit = getattr(charm.on, event_name)
except AttributeError:
logger.debug("event %s not defined for %s", event_name, charm)
# If the event is not supported by the charm implementation, do
# not error out or try to emit it. This is to support rollbacks.
if event_to_emit is not None:
args, kwargs = _get_event_args(charm, event_to_emit)
logger.debug('Emitting Juju event %s', event_name)
event_to_emit.emit(*args, **kwargs)
def _get_event_args(charm, bound_event):
event_type = bound_event.event_type
model = charm.framework.model
if issubclass(event_type, ops.charm.RelationEvent):
relation_name = os.environ['JUJU_RELATION']
relation_id = int(os.environ['JUJU_RELATION_ID'].split(':')[-1])
relation = model.get_relation(relation_name, relation_id)
else:
relation = None
remote_app_name = os.environ.get('JUJU_REMOTE_APP', '')
remote_unit_name = os.environ.get('JUJU_REMOTE_UNIT', '')
if remote_app_name or remote_unit_name:
if not remote_app_name:
if '/' not in remote_unit_name:
raise RuntimeError('invalid remote unit name: {}'.format(remote_unit_name))
remote_app_name = remote_unit_name.split('/')[0]
args = [relation, model.get_app(remote_app_name)]
if remote_unit_name:
args.append(model.get_unit(remote_unit_name))
return args, {}
elif relation:
return [relation], {}
return [], {}
def main(charm_class):
"""Setup the charm and dispatch the observed event.
The event name is based on the way this executable was called (argv[0]).
"""
charm_dir = _get_charm_dir()
model_backend = ops.model.ModelBackend()
debug = ('JUJU_DEBUG' in os.environ)
setup_root_logging(model_backend, debug=debug)
# Process the Juju event relevant to the current hook execution
# JUJU_HOOK_NAME, JUJU_FUNCTION_NAME, and JUJU_ACTION_NAME are not used
# in order to support simulation of events from debugging sessions.
#
# TODO: For Windows, when symlinks are used, this is not a valid
# method of getting an event name (see LP: #1854505).
juju_exec_path = Path(sys.argv[0])
has_dispatch = juju_exec_path.name == 'dispatch'
if has_dispatch:
# The executable was 'dispatch', which means the actual hook we want to
# run needs to be looked up in the JUJU_DISPATCH_PATH env var, where it
# should be a path relative to the charm directory (the directory that
# holds `dispatch`). If that path actually exists, we want to run that
# before continuing.
dispatch_path = juju_exec_path.parent / Path(os.environ['JUJU_DISPATCH_PATH'])
if dispatch_path.exists() and dispatch_path.resolve() != juju_exec_path.resolve():
argv = sys.argv.copy()
argv[0] = str(dispatch_path)
try:
subprocess.run(argv, check=True)
except subprocess.CalledProcessError as e:
logger.warning("hook %s exited with status %d", dispatch_path, e.returncode)
sys.exit(e.returncode)
juju_exec_path = dispatch_path
juju_event_name = juju_exec_path.name.replace('-', '_')
if juju_exec_path.parent.name == 'actions':
juju_event_name = '{}_action'.format(juju_event_name)
metadata, actions_metadata = _load_metadata(charm_dir)
meta = ops.charm.CharmMeta(metadata, actions_metadata)
unit_name = os.environ['JUJU_UNIT_NAME']
model = ops.model.Model(unit_name, meta, model_backend)
# TODO: If Juju unit agent crashes after exit(0) from the charm code
# the framework will commit the snapshot but Juju will not commit its
# operation.
charm_state_path = charm_dir / CHARM_STATE_FILE
framework = ops.framework.Framework(charm_state_path, charm_dir, meta, model)
try:
charm = charm_class(framework, None)
if not has_dispatch:
# When a charm is force-upgraded and a unit is in an error state Juju
# does not run upgrade-charm and instead runs the failed hook followed
# by config-changed. Given the nature of force-upgrading the hook setup
# code is not triggered on config-changed.
#
# 'start' event is included as Juju does not fire the install event for
# K8s charms (see LP: #1854635).
if (juju_event_name in ('install', 'start', 'upgrade_charm')
or juju_event_name.endswith('_storage_attached')):
_setup_event_links(charm_dir, charm)
# TODO: Remove the collect_metrics check below as soon as the relevant
# Juju changes are made.
#
# Skip reemission of deferred events for collect-metrics events because
# they do not have the full access to all hook tools.
if juju_event_name != 'collect_metrics':
framework.reemit()
_emit_charm_event(charm, juju_event_name)
framework.commit()
finally:
framework.close()
# Copyright 2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import pathlib
from textwrap import dedent
import typing
from ops import charm, framework, model
# OptionalYAML is something like metadata.yaml or actions.yaml. You can
# pass in a file-like object or the string directly.
OptionalYAML = typing.Optional[typing.Union[str, typing.TextIO]]
# noinspection PyProtectedMember
class Harness:
"""This class represents a way to build up the model that will drive a test suite.
The model that is created is from the viewpoint of the charm that you are testing.
Example::
harness = Harness(MyCharm)
# Do initial setup here
relation_id = harness.add_relation('db', 'postgresql')
# Now instantiate the charm to see events as the model changes
harness.begin()
harness.add_relation_unit(relation_id, 'postgresql/0')
harness.update_relation_data(relation_id, 'postgresql/0', {'key': 'val'})
# Check that charm has properly handled the relation_joined event for postgresql/0
self.assertEqual(harness.charm. ...)
Args:
charm_cls: The Charm class that you'll be testing.
meta: charm.CharmBase is a A string or file-like object containing the contents of
metadata.yaml. If not supplied, we will look for a 'metadata.yaml' file in the
parent directory of the Charm, and if not found fall back to a trivial
'name: test-charm' metadata.
actions: A string or file-like object containing the contents of
actions.yaml. If not supplied, we will look for a 'actions.yaml' file in the
parent directory of the Charm.
"""
def __init__(
self,
charm_cls: typing.Type[charm.CharmBase],
*,
meta: OptionalYAML = None,
actions: OptionalYAML = None):
# TODO: jam 2020-03-05 We probably want to take config as a parameter as well, since
# it would define the default values of config that the charm would see.
self._charm_cls = charm_cls
self._charm = None
self._charm_dir = 'no-disk-path' # this may be updated by _create_meta
self._meta = self._create_meta(meta, actions)
self._unit_name = self._meta.name + '/0'
self._framework = None
self._hooks_enabled = True
self._relation_id_counter = 0
self._backend = _TestingModelBackend(self._unit_name, self._meta)
self._model = model.Model(self._unit_name, self._meta, self._backend)
self._framework = framework.Framework(":memory:", self._charm_dir, self._meta, self._model)
@property
def charm(self) -> charm.CharmBase:
"""Return the instance of the charm class that was passed to __init__.
Note that the Charm is not instantiated until you have called
:meth:`.begin()`.
"""
return self._charm
@property
def model(self) -> model.Model:
"""Return the :class:`~ops.model.Model` that is being driven by this Harness."""
return self._model
@property
def framework(self) -> framework.Framework:
"""Return the Framework that is being driven by this Harness."""
return self._framework
def begin(self) -> None:
"""Instantiate the Charm and start handling events.
Before calling begin(), there is no Charm instance, so changes to the Model won't emit
events. You must call begin before :attr:`.charm` is valid.
"""
if self._charm is not None:
raise RuntimeError('cannot call the begin method on the harness more than once')
# The Framework adds attributes to class objects for events, etc. As such, we can't re-use
# the original class against multiple Frameworks. So create a locally defined class
# and register it.
# TODO: jam 2020-03-16 We are looking to changes this to Instance attributes instead of
# Class attributes which should clean up this ugliness. The API can stay the same
class TestEvents(self._charm_cls.on.__class__):
pass
TestEvents.__name__ = self._charm_cls.on.__class__.__name__
class TestCharm(self._charm_cls):
on = TestEvents()
# Note: jam 2020-03-01 This is so that errors in testing say MyCharm has no attribute foo,
# rather than TestCharm has no attribute foo.
TestCharm.__name__ = self._charm_cls.__name__
self._charm = TestCharm(self._framework, self._framework.meta.name)
def _create_meta(self, charm_metadata, action_metadata):
"""Create a CharmMeta object.
Handle the cases where a user doesn't supply explicit metadata snippets.
"""
filename = inspect.getfile(self._charm_cls)
charm_dir = pathlib.Path(filename).parents[1]
if charm_metadata is None:
metadata_path = charm_dir / 'metadata.yaml'
if metadata_path.is_file():
charm_metadata = metadata_path.read_text()
self._charm_dir = charm_dir
else:
# The simplest of metadata that the framework can support
charm_metadata = 'name: test-charm'
elif isinstance(charm_metadata, str):
charm_metadata = dedent(charm_metadata)
if action_metadata is None:
actions_path = charm_dir / 'actions.yaml'
if actions_path.is_file():
action_metadata = actions_path.read_text()
self._charm_dir = charm_dir
elif isinstance(action_metadata, str):
action_metadata = dedent(action_metadata)
return charm.CharmMeta.from_yaml(charm_metadata, action_metadata)
def disable_hooks(self) -> None:
"""Stop emitting hook events when the model changes.
This can be used by developers to stop changes to the model from emitting events that
the charm will react to. Call :meth:`.enable_hooks`
to re-enable them.
"""
self._hooks_enabled = False
def enable_hooks(self) -> None:
"""Re-enable hook events from charm.on when the model is changed.
By default hook events are enabled once you call :meth:`.begin`,
but if you have used :meth:`.disable_hooks`, this can be used to
enable them again.
"""
self._hooks_enabled = True
def _next_relation_id(self):
rel_id = self._relation_id_counter
self._relation_id_counter += 1
return rel_id
def add_relation(self, relation_name: str, remote_app: str) -> int:
"""Declare that there is a new relation between this app and `remote_app`.
Args:
relation_name: The relation on Charm that is being related to
remote_app: The name of the application that is being related to
Return:
The relation_id created by this add_relation.
"""
rel_id = self._next_relation_id()
self._backend._relation_ids_map.setdefault(relation_name, []).append(rel_id)
self._backend._relation_names[rel_id] = relation_name
self._backend._relation_list_map[rel_id] = []
self._backend._relation_data[rel_id] = {
remote_app: {},
self._backend.unit_name: {},
self._backend.app_name: {},
}
# Reload the relation_ids list
if self._model is not None:
self._model.relations._invalidate(relation_name)
if self._charm is None or not self._hooks_enabled:
return rel_id
relation = self._model.get_relation(relation_name, rel_id)
app = self._model.get_app(remote_app)
self._charm.on[relation_name].relation_created.emit(
relation, app)
return rel_id
def add_relation_unit(self, relation_id: int, remote_unit_name: str) -> None:
"""Add a new unit to a relation.
Example::
rel_id = harness.add_relation('db', 'postgresql')
harness.add_relation_unit(rel_id, 'postgresql/0')
This will trigger a `relation_joined` event and a `relation_changed` event.
Args:
relation_id: The integer relation identifier (as returned by add_relation).
remote_unit_name: A string representing the remote unit that is being added.
Return:
None
"""
self._backend._relation_list_map[relation_id].append(remote_unit_name)
self._backend._relation_data[relation_id][remote_unit_name] = {}
relation_name = self._backend._relation_names[relation_id]
# Make sure that the Model reloads the relation_list for this relation_id, as well as
# reloading the relation data for this unit.
if self._model is not None:
self._model.relations._invalidate(relation_name)
remote_unit = self._model.get_unit(remote_unit_name)
relation = self._model.get_relation(relation_name, relation_id)
relation.data[remote_unit]._invalidate()
if self._charm is None or not self._hooks_enabled:
return
self._charm.on[relation_name].relation_joined.emit(
relation, remote_unit.app, remote_unit)
def get_relation_data(self, relation_id: int, app_or_unit: str) -> typing.Mapping:
"""Get the relation data bucket for a single app or unit in a given relation.
This ignores all of the safety checks of who can and can't see data in relations (eg,
non-leaders can't read their own application's relation data because there are no events
that keep that data up-to-date for the unit).
Args:
relation_id: The relation whose content we want to look at.
app_or_unit: The name of the application or unit whose data we want to read
Return:
a dict containing the relation data for `app_or_unit` or None.
Raises:
KeyError: if relation_id doesn't exist
"""
return self._backend._relation_data[relation_id].get(app_or_unit, None)
def get_workload_version(self) -> str:
"""Read the workload version that was set by the unit."""
return self._backend._workload_version
def update_relation_data(
self,
relation_id: int,
app_or_unit: str,
key_values: typing.Mapping,
) -> None:
"""Update the relation data for a given unit or application in a given relation.
This also triggers the `relation_changed` event for this relation_id.
Args:
relation_id: The integer relation_id representing this relation.
app_or_unit: The unit or application name that is being updated.
This can be the local or remote application.
key_values: Each key/value will be updated in the relation data.
"""
relation_name = self._backend._relation_names[relation_id]
relation = self._model.get_relation(relation_name, relation_id)
if '/' in app_or_unit:
entity = self._model.get_unit(app_or_unit)
else:
entity = self._model.get_app(app_or_unit)
rel_data = relation.data.get(entity, None)
if rel_data is not None:
# rel_data may have cached now-stale data, so _invalidate() it.
# Note, this won't cause the data to be loaded if it wasn't already.
rel_data._invalidate()
new_values = self._backend._relation_data[relation_id][app_or_unit].copy()
for k, v in key_values.items():
if v == '':
new_values.pop(k, None)
else:
new_values[k] = v
self._backend._relation_data[relation_id][app_or_unit] = new_values
if app_or_unit == self._model.unit.name:
# No events for our own unit
return
if app_or_unit == self._model.app.name:
# updating our own app only generates an event if it is a peer relation and we
# aren't the leader
is_peer = self._meta.relations[relation_name].role == 'peers'
if not is_peer:
return
if self._model.unit.is_leader():
return
self._emit_relation_changed(relation_id, app_or_unit)
def _emit_relation_changed(self, relation_id, app_or_unit):
if self._charm is None or not self._hooks_enabled:
return
rel_name = self._backend._relation_names[relation_id]
relation = self.model.get_relation(rel_name, relation_id)
if '/' in app_or_unit:
app_name = app_or_unit.split('/')[0]
unit_name = app_or_unit
app = self.model.get_app(app_name)
unit = self.model.get_unit(unit_name)
args = (relation, app, unit)
else:
app_name = app_or_unit
app = self.model.get_app(app_name)
args = (relation, app)
self._charm.on[rel_name].relation_changed.emit(*args)
def update_config(
self,
key_values: typing.Mapping[str, str] = None,
unset: typing.Iterable[str] = (),
) -> None:
"""Update the config as seen by the charm.
This will trigger a `config_changed` event.
Args:
key_values: A Mapping of key:value pairs to update in config.
unset: An iterable of keys to remove from Config. (Note that this does
not currently reset the config values to the default defined in config.yaml.)
"""
config = self._backend._config
if key_values is not None:
for key, value in key_values.items():
config[key] = value
for key in unset:
config.pop(key, None)
# NOTE: jam 2020-03-01 Note that this sort of works "by accident". Config
# is a LazyMapping, but its _load returns a dict and this method mutates
# the dict that Config is caching. Arguably we should be doing some sort
# of charm.framework.model.config._invalidate()
if self._charm is None or not self._hooks_enabled:
return
self._charm.on.config_changed.emit()
def set_leader(self, is_leader: bool = True) -> None:
"""Set whether this unit is the leader or not.
If this charm becomes a leader then `leader_elected` will be triggered.
Args:
is_leader: True/False as to whether this unit is the leader.
"""
was_leader = self._backend._is_leader
self._backend._is_leader = is_leader
# Note: jam 2020-03-01 currently is_leader is cached at the ModelBackend level, not in
# the Model objects, so this automatically gets noticed.
if is_leader and not was_leader and self._charm is not None and self._hooks_enabled:
self._charm.on.leader_elected.emit()
class _TestingModelBackend:
"""This conforms to the interface for ModelBackend but provides canned data.
DO NOT use this class directly, it is used by `Harness`_ to drive the model.
`Harness`_ is responsible for maintaining the internal consistency of the values here,
as the only public methods of this type are for implementing ModelBackend.
"""
def __init__(self, unit_name, meta):
self.unit_name = unit_name
self.app_name = self.unit_name.split('/')[0]
self._calls = []
self._meta = meta
self._is_leader = None
self._relation_ids_map = {} # relation name to [relation_ids,...]
self._relation_names = {} # reverse map from relation_id to relation_name
self._relation_list_map = {} # relation_id: [unit_name,...]
self._relation_data = {} # {relation_id: {name: data}}
self._config = {}
self._is_leader = False
self._resources_map = {}
self._pod_spec = None
self._app_status = None
self._unit_status = None
self._workload_version = None
def relation_ids(self, relation_name):
try:
return self._relation_ids_map[relation_name]
except KeyError as e:
if relation_name not in self._meta.relations:
raise model.ModelError('{} is not a known relation'.format(relation_name)) from e
return []
def relation_list(self, relation_id):
try:
return self._relation_list_map[relation_id]
except KeyError as e:
raise model.RelationNotFoundError from e
def relation_get(self, relation_id, member_name, is_app):
if is_app and '/' in member_name:
member_name = member_name.split('/')[0]
if relation_id not in self._relation_data:
raise model.RelationNotFoundError()
return self._relation_data[relation_id][member_name].copy()
def relation_set(self, relation_id, key, value, is_app):
relation = self._relation_data[relation_id]
if is_app:
bucket_key = self.app_name
else:
bucket_key = self.unit_name
if bucket_key not in relation:
relation[bucket_key] = {}
bucket = relation[bucket_key]
if value == '':
bucket.pop(key, None)
else:
bucket[key] = value
def config_get(self):
return self._config
def is_leader(self):
return self._is_leader
def application_version_set(self, version):
self._workload_version = version
def resource_get(self, resource_name):
return self._resources_map[resource_name]
def pod_spec_set(self, spec, k8s_resources):
self._pod_spec = (spec, k8s_resources)
def status_get(self, *, is_app=False):
if is_app:
return self._app_status
else:
return self._unit_status
def status_set(self, status, message='', *, is_app=False):
if is_app:
self._app_status = (status, message)
else:
self._unit_status = (status, message)
def storage_list(self, name):
raise NotImplementedError(self.storage_list)
def storage_get(self, storage_name_id, attribute):
raise NotImplementedError(self.storage_get)
def storage_add(self, name, count=1):
raise NotImplementedError(self.storage_add)
def action_get(self):
raise NotImplementedError(self.action_get)
def action_set(self, results):
raise NotImplementedError(self.action_set)
def action_log(self, message):
raise NotImplementedError(self.action_log)
def action_fail(self, message=''):
raise NotImplementedError(self.action_fail)
def network_get(self, endpoint_name, relation_id=None):
raise NotImplementedError(self.network_get)
# Copyright 2019 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup
with open("README.md", "r") as fh:
long_description = fh.read()
setup(
name="ops",
version="0.0.1",
description="The Python library behind great charms",
long_description=long_description,
long_description_content_type="text/markdown",
license="Apache-2.0",
url="https://github.com/canonical/operator",
packages=["ops"],
classifiers=[
"Development Status :: 4 - Beta",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"License :: OSI Approved :: Apache Software License",
],
)
#!/bin/bash
case $1 in
db) echo '["db:1"]' ;;
mon) echo '["mon:2"]' ;;
ha) echo '[]' ;;
db0) echo '[]' ;;
db1) echo '["db1:4"]' ;;
db2) echo '["db2:5", "db2:6"]' ;;
*) echo '[]' ;;
esac
#!/bin/bash
fail_not_found() {
1>&2 echo "ERROR invalid value \"$1\" for option -r: relation not found"
exit 2
}
case $2 in
1) echo '["remote/0"]' ;;
2) echo '["remote/0"]' ;;
3) fail_not_found $2 ;;
4) echo '["remoteapp1/0"]' ;;
5) echo '["remoteapp1/0"]' ;;
6) echo '["remoteapp2/0"]' ;;
*) fail_not_found $2 ;;
esac
name: main
summary: A charm used for testing the basic operation of the entrypoint code.
maintainer: Dmitrii Shcherbakov <dmitrii.shcherbakov@canonical.com>
description: A charm used for testing the basic operation of the entrypoint code.
tags:
- misc
series:
- bionic
- cosmic
- disco
min-juju-version: 2.7.1
provides:
db:
interface: db
requires:
mon:
interface: monitoring
peers:
ha:
interface: cluster
subordinate: false
storage:
disks:
type: block
multiple:
range: 0-
#!/usr/bin/env python3
# Copyright 2019 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import base64
import pickle
import sys
import logging
sys.path.append('lib')
from ops.charm import CharmBase # noqa: E402 (module-level import after non-import code)
from ops.main import main # noqa: E402 (ditto)
logger = logging.getLogger()
class Charm(CharmBase):
def __init__(self, *args):
super().__init__(*args)
# This environment variable controls the test charm behavior.
charm_config = os.environ.get('CHARM_CONFIG')
if charm_config is not None:
self._charm_config = pickle.loads(base64.b64decode(charm_config))
else:
self._charm_config = {}
# TODO: refactor to use StoredState
# (this implies refactoring most of test_main.py)
self._state_file = self._charm_config.get('STATE_FILE')
try:
with open(str(self._state_file), 'rb') as f:
self._state = pickle.load(f)
except (FileNotFoundError, EOFError):
self._state = {
'on_install': [],
'on_start': [],
'on_config_changed': [],
'on_update_status': [],
'on_leader_settings_changed': [],
'on_db_relation_joined': [],
'on_mon_relation_changed': [],
'on_mon_relation_departed': [],
'on_ha_relation_broken': [],
'on_foo_bar_action': [],
'on_start_action': [],
'on_collect_metrics': [],
'on_log_critical_action': [],
'on_log_error_action': [],
'on_log_warning_action': [],
'on_log_info_action': [],
'on_log_debug_action': [],
# Observed event types per invocation. A list is used to preserve the
# order in which charm handlers have observed the events.
'observed_event_types': [],
}
self.framework.observe(self.on.install, self)
self.framework.observe(self.on.start, self)
self.framework.observe(self.on.config_changed, self)
self.framework.observe(self.on.update_status, self)
self.framework.observe(self.on.leader_settings_changed, self)
# Test relation events with endpoints from different
# sections (provides, requires, peers) as well.
self.framework.observe(self.on.db_relation_joined, self)
self.framework.observe(self.on.mon_relation_changed, self)
self.framework.observe(self.on.mon_relation_departed, self)
self.framework.observe(self.on.ha_relation_broken, self)
if self._charm_config.get('USE_ACTIONS'):
self.framework.observe(self.on.start_action, self)
self.framework.observe(self.on.foo_bar_action, self)
self.framework.observe(self.on.collect_metrics, self)
if self._charm_config.get('USE_LOG_ACTIONS'):
self.framework.observe(self.on.log_critical_action, self)
self.framework.observe(self.on.log_error_action, self)
self.framework.observe(self.on.log_warning_action, self)
self.framework.observe(self.on.log_info_action, self)
self.framework.observe(self.on.log_debug_action, self)
def _write_state(self):
"""Write state variables so that the parent process can read them.
Each invocation will override the previous state which is intentional.
"""
if self._state_file is not None:
with self._state_file.open('wb') as f:
pickle.dump(self._state, f)
def on_install(self, event):
self._state['on_install'].append(type(event))
self._state['observed_event_types'].append(type(event))
self._write_state()
def on_start(self, event):
self._state['on_start'].append(type(event))
self._state['observed_event_types'].append(type(event))
self._write_state()
def on_config_changed(self, event):
self._state['on_config_changed'].append(type(event))
self._state['observed_event_types'].append(type(event))
event.defer()
self._write_state()
def on_update_status(self, event):
self._state['on_update_status'].append(type(event))
self._state['observed_event_types'].append(type(event))
self._write_state()
def on_leader_settings_changed(self, event):
self._state['on_leader_settings_changed'].append(type(event))
self._state['observed_event_types'].append(type(event))
self._write_state()
def on_db_relation_joined(self, event):
assert event.app is not None, 'application name cannot be None for a relation-joined event'
self._state['on_db_relation_joined'].append(type(event))
self._state['observed_event_types'].append(type(event))
self._state['db_relation_joined_data'] = event.snapshot()
self._write_state()
def on_mon_relation_changed(self, event):
assert event.app is not None, (
'application name cannot be None for a relation-changed event')
if os.environ.get('JUJU_REMOTE_UNIT'):
assert event.unit is not None, (
'a unit name cannot be None for a relation-changed event'
' associated with a remote unit')
self._state['on_mon_relation_changed'].append(type(event))
self._state['observed_event_types'].append(type(event))
self._state['mon_relation_changed_data'] = event.snapshot()
self._write_state()
def on_mon_relation_departed(self, event):
assert event.app is not None, (
'application name cannot be None for a relation-departed event')
self._state['on_mon_relation_departed'].append(type(event))
self._state['observed_event_types'].append(type(event))
self._state['mon_relation_departed_data'] = event.snapshot()
self._write_state()
def on_ha_relation_broken(self, event):
assert event.app is None, (
'relation-broken events cannot have a reference to a remote application')
assert event.unit is None, (
'relation broken events cannot have a reference to a remote unit')
self._state['on_ha_relation_broken'].append(type(event))
self._state['observed_event_types'].append(type(event))
self._state['ha_relation_broken_data'] = event.snapshot()
self._write_state()
def on_start_action(self, event):
assert event.handle.kind == 'start_action', (
'event action name cannot be different from the one being handled')
self._state['on_start_action'].append(type(event))
self._state['observed_event_types'].append(type(event))
self._write_state()
def on_foo_bar_action(self, event):
assert event.handle.kind == 'foo_bar_action', (
'event action name cannot be different from the one being handled')
self._state['on_foo_bar_action'].append(type(event))
self._state['observed_event_types'].append(type(event))
self._write_state()
def on_collect_metrics(self, event):
self._state['on_collect_metrics'].append(type(event))
self._state['observed_event_types'].append(type(event))
event.add_metrics({'foo': 42}, {'bar': 4.2})
self._write_state()
def on_log_critical_action(self, event):
logger.critical('super critical')
def on_log_error_action(self, event):
logger.error('grave error')
def on_log_warning_action(self, event):
logger.warning('wise warning')
def on_log_info_action(self, event):
logger.info('useful info')
def on_log_debug_action(self, event):
logger.debug('insightful debug')
if __name__ == '__main__':
main(Charm)
#!/usr/bin/python3
# Copyright 2019 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import tempfile
import shutil
from pathlib import Path
from ops.charm import (
CharmBase,
CharmMeta,
CharmEvents,
)
from ops.framework import Framework, EventSource, EventBase
from ops.model import Model, ModelBackend
from .test_helpers import fake_script, fake_script_calls
class TestCharm(unittest.TestCase):
def setUp(self):
def restore_env(env):
os.environ.clear()
os.environ.update(env)
self.addCleanup(restore_env, os.environ.copy())
os.environ['PATH'] = "{}:{}".format(Path(__file__).parent / 'bin', os.environ['PATH'])
os.environ['JUJU_UNIT_NAME'] = 'local/0'
self.tmpdir = Path(tempfile.mkdtemp())
self.addCleanup(shutil.rmtree, str(self.tmpdir))
self.meta = CharmMeta()
class CustomEvent(EventBase):
pass
class TestCharmEvents(CharmEvents):
custom = EventSource(CustomEvent)
# Relations events are defined dynamically and modify the class attributes.
# We use a subclass temporarily to prevent these side effects from leaking.
CharmBase.on = TestCharmEvents()
def cleanup():
CharmBase.on = CharmEvents()
self.addCleanup(cleanup)
def create_framework(self):
model = Model('local/0', self.meta, ModelBackend())
framework = Framework(self.tmpdir / "framework.data", self.tmpdir, self.meta, model)
self.addCleanup(framework.close)
return framework
def test_basic(self):
class MyCharm(CharmBase):
def __init__(self, *args):
super().__init__(*args)
self.started = False
framework.observe(self.on.start, self)
def on_start(self, event):
self.started = True
events = list(MyCharm.on.events())
self.assertIn('install', events)
self.assertIn('custom', events)
framework = self.create_framework()
charm = MyCharm(framework, None)
charm.on.start.emit()
self.assertEqual(charm.started, True)
def test_helper_properties(self):
framework = self.create_framework()
class MyCharm(CharmBase):
pass
charm = MyCharm(framework, None)
self.assertEqual(charm.app, framework.model.app)
self.assertEqual(charm.unit, framework.model.unit)
self.assertEqual(charm.meta, framework.meta)
self.assertEqual(charm.charm_dir, framework.charm_dir)
def test_relation_events(self):
class MyCharm(CharmBase):
def __init__(self, *args):
super().__init__(*args)
self.seen = []
for rel in ('req1', 'req-2', 'pro1', 'pro-2', 'peer1', 'peer-2'):
# Hook up relation events to generic handler.
self.framework.observe(self.on[rel].relation_joined, self.on_any_relation)
self.framework.observe(self.on[rel].relation_changed, self.on_any_relation)
self.framework.observe(self.on[rel].relation_departed, self.on_any_relation)
self.framework.observe(self.on[rel].relation_broken, self.on_any_relation)
def on_any_relation(self, event):
assert event.relation.name == 'req1'
assert event.relation.app.name == 'remote'
self.seen.append(type(event).__name__)
# language=YAML
self.meta = CharmMeta.from_yaml(metadata='''
name: my-charm
requires:
req1:
interface: req1
req-2:
interface: req2
provides:
pro1:
interface: pro1
pro-2:
interface: pro2
peers:
peer1:
interface: peer1
peer-2:
interface: peer2
''')
charm = MyCharm(self.create_framework(), None)
rel = charm.framework.model.get_relation('req1', 1)
unit = charm.framework.model.get_unit('remote/0')
charm.on['req1'].relation_joined.emit(rel, unit)
charm.on['req1'].relation_changed.emit(rel, unit)
charm.on['req-2'].relation_changed.emit(rel, unit)
charm.on['pro1'].relation_departed.emit(rel, unit)
charm.on['pro-2'].relation_departed.emit(rel, unit)
charm.on['peer1'].relation_broken.emit(rel)
charm.on['peer-2'].relation_broken.emit(rel)
self.assertEqual(charm.seen, [
'RelationJoinedEvent',
'RelationChangedEvent',
'RelationChangedEvent',
'RelationDepartedEvent',
'RelationDepartedEvent',
'RelationBrokenEvent',
'RelationBrokenEvent',
])
def test_storage_events(self):
class MyCharm(CharmBase):
def __init__(self, *args):
super().__init__(*args)
self.seen = []
self.framework.observe(self.on['stor1'].storage_attached, self)
self.framework.observe(self.on['stor2'].storage_detaching, self)
self.framework.observe(self.on['stor3'].storage_attached, self)
self.framework.observe(self.on['stor-4'].storage_attached, self)
def on_stor1_storage_attached(self, event):
self.seen.append(type(event).__name__)
def on_stor2_storage_detaching(self, event):
self.seen.append(type(event).__name__)
def on_stor3_storage_attached(self, event):
self.seen.append(type(event).__name__)
def on_stor_4_storage_attached(self, event):
self.seen.append(type(event).__name__)
# language=YAML
self.meta = CharmMeta.from_yaml('''
name: my-charm
storage:
stor-4:
multiple:
range: 2-4
type: filesystem
stor1:
type: filesystem
stor2:
multiple:
range: "2"
type: filesystem
stor3:
multiple:
range: 2-
type: filesystem
''')
self.assertIsNone(self.meta.storages['stor1'].multiple_range)
self.assertEqual(self.meta.storages['stor2'].multiple_range, (2, 2))
self.assertEqual(self.meta.storages['stor3'].multiple_range, (2, None))
self.assertEqual(self.meta.storages['stor-4'].multiple_range, (2, 4))
charm = MyCharm(self.create_framework(), None)
charm.on['stor1'].storage_attached.emit()
charm.on['stor2'].storage_detaching.emit()
charm.on['stor3'].storage_attached.emit()
charm.on['stor-4'].storage_attached.emit()
self.assertEqual(charm.seen, [
'StorageAttachedEvent',
'StorageDetachingEvent',
'StorageAttachedEvent',
'StorageAttachedEvent',
])
@classmethod
def _get_action_test_meta(cls):
# language=YAML
return CharmMeta.from_yaml(metadata='''
name: my-charm
''', actions='''
foo-bar:
description: "Foos the bar."
params:
foo-name:
description: "A foo name to bar"
type: string
silent:
default: false
description: ""
type: boolean
required: foo-bar
title: foo-bar
start:
description: "Start the unit."
''')
def _test_action_events(self, cmd_type):
class MyCharm(CharmBase):
def __init__(self, *args):
super().__init__(*args)
framework.observe(self.on.foo_bar_action, self)
framework.observe(self.on.start_action, self)
def on_foo_bar_action(self, event):
self.seen_action_params = event.params
event.log('test-log')
event.set_results({'res': 'val with spaces'})
event.fail('test-fail')
def on_start_action(self, event):
pass
fake_script(self, cmd_type + '-get', """echo '{"foo-name": "name", "silent": true}'""")
fake_script(self, cmd_type + '-set', "")
fake_script(self, cmd_type + '-log', "")
fake_script(self, cmd_type + '-fail', "")
self.meta = self._get_action_test_meta()
os.environ['JUJU_{}_NAME'.format(cmd_type.upper())] = 'foo-bar'
framework = self.create_framework()
charm = MyCharm(framework, None)
events = list(MyCharm.on.events())
self.assertIn('foo_bar_action', events)
self.assertIn('start_action', events)
charm.on.foo_bar_action.emit()
self.assertEqual(charm.seen_action_params, {"foo-name": "name", "silent": True})
self.assertEqual(fake_script_calls(self), [
[cmd_type + '-get', '--format=json'],
[cmd_type + '-log', "test-log"],
[cmd_type + '-set', "res=val with spaces"],
[cmd_type + '-fail', "test-fail"],
])
# Make sure that action events that do not match the current context are
# not possible to emit by hand.
with self.assertRaises(RuntimeError):
charm.on.start_action.emit()
def test_action_events(self):
self._test_action_events('action')
def _test_action_event_defer_fails(self, cmd_type):
class MyCharm(CharmBase):
def __init__(self, *args):
super().__init__(*args)
framework.observe(self.on.start_action, self)
def on_start_action(self, event):
event.defer()
fake_script(self, cmd_type + '-get', """echo '{"foo-name": "name", "silent": true}'""")
self.meta = self._get_action_test_meta()
os.environ['JUJU_{}_NAME'.format(cmd_type.upper())] = 'start'
framework = self.create_framework()
charm = MyCharm(framework, None)
with self.assertRaises(RuntimeError):
charm.on.start_action.emit()
def test_action_event_defer_fails(self):
self._test_action_event_defer_fails('action')
if __name__ == "__main__":
unittest.main()
# Copyright 2019-2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import gc
import inspect
import io
import os
import shutil
import sys
import tempfile
import unittest
from unittest.mock import patch
from pathlib import Path
from ops import charm, model
from ops.framework import (
_BREAKPOINT_WELCOME_MESSAGE,
BoundStoredState,
CommitEvent,
EventBase,
ObjectEvents,
EventSource,
Framework,
Handle,
NoSnapshotError,
Object,
PreCommitEvent,
SQLiteStorage,
StoredList,
StoredState,
StoredStateData,
)
from test.test_helpers import fake_script
class TestFramework(unittest.TestCase):
def setUp(self):
self.tmpdir = Path(tempfile.mkdtemp())
self.addCleanup(shutil.rmtree, str(self.tmpdir))
default_timeout = SQLiteStorage.DB_LOCK_TIMEOUT
def timeout_cleanup():
SQLiteStorage.DB_LOCK_TIMEOUT = default_timeout
SQLiteStorage.DB_LOCK_TIMEOUT = datetime.timedelta(0)
self.addCleanup(timeout_cleanup)
def create_framework(self):
framework = Framework(self.tmpdir / "framework.data", self.tmpdir, None, None)
self.addCleanup(framework.close)
return framework
def test_handle_path(self):
cases = [
(Handle(None, "root", None), "root"),
(Handle(None, "root", "1"), "root[1]"),
(Handle(Handle(None, "root", None), "child", None), "root/child"),
(Handle(Handle(None, "root", "1"), "child", "2"), "root[1]/child[2]"),
]
for handle, path in cases:
self.assertEqual(str(handle), path)
self.assertEqual(Handle.from_path(path), handle)
def test_handle_attrs_readonly(self):
handle = Handle(None, 'kind', 'key')
with self.assertRaises(AttributeError):
handle.parent = 'foo'
with self.assertRaises(AttributeError):
handle.kind = 'foo'
with self.assertRaises(AttributeError):
handle.key = 'foo'
with self.assertRaises(AttributeError):
handle.path = 'foo'
def test_restore_unknown(self):
framework = self.create_framework()
class Foo(Object):
pass
handle = Handle(None, "a_foo", "some_key")
framework.register_type(Foo, None, handle.kind)
try:
framework.load_snapshot(handle)
except NoSnapshotError as e:
self.assertEqual(e.handle_path, str(handle))
self.assertEqual(str(e), "no snapshot data found for a_foo[some_key] object")
else:
self.fail("exception NoSnapshotError not raised")
def test_snapshot_roundtrip(self):
class Foo:
def __init__(self, handle, n):
self.handle = handle
self.my_n = n
def snapshot(self):
return {"My N!": self.my_n}
def restore(self, snapshot):
self.my_n = snapshot["My N!"] + 1
handle = Handle(None, "a_foo", "some_key")
event = Foo(handle, 1)
framework1 = self.create_framework()
framework1.register_type(Foo, None, handle.kind)
framework1.save_snapshot(event)
framework1.commit()
framework1.close()
framework2 = self.create_framework()
framework2.register_type(Foo, None, handle.kind)
event2 = framework2.load_snapshot(handle)
self.assertEqual(event2.my_n, 2)
framework2.save_snapshot(event2)
del event2
gc.collect()
event3 = framework2.load_snapshot(handle)
self.assertEqual(event3.my_n, 3)
framework2.drop_snapshot(event.handle)
framework2.commit()
framework2.close()
framework3 = self.create_framework()
framework3.register_type(Foo, None, handle.kind)
self.assertRaises(NoSnapshotError, framework3.load_snapshot, handle)
def test_simple_event_observer(self):
framework = self.create_framework()
class MyEvent(EventBase):
pass
class MyNotifier(Object):
foo = EventSource(MyEvent)
bar = EventSource(MyEvent)
baz = EventSource(MyEvent)
class MyObserver(Object):
def __init__(self, parent, key):
super().__init__(parent, key)
self.seen = []
def on_any(self, event):
self.seen.append("on_any:" + event.handle.kind)
def on_foo(self, event):
self.seen.append("on_foo:" + event.handle.kind)
pub = MyNotifier(framework, "1")
obs = MyObserver(framework, "1")
framework.observe(pub.foo, obs.on_any)
framework.observe(pub.bar, obs.on_any)
framework.observe(pub.foo, obs) # Method name defaults to on_<event kind>.
try:
framework.observe(pub.baz, obs)
except RuntimeError as e:
self.assertEqual(
str(e),
'Observer method not provided explicitly'
' and MyObserver type has no "on_baz" method')
else:
self.fail("RuntimeError not raised")
pub.foo.emit()
pub.bar.emit()
self.assertEqual(obs.seen, ["on_any:foo", "on_foo:foo", "on_any:bar"])
def test_bad_sig_observer(self):
class MyEvent(EventBase):
pass
class MyNotifier(Object):
foo = EventSource(MyEvent)
bar = EventSource(MyEvent)
baz = EventSource(MyEvent)
qux = EventSource(MyEvent)
class MyObserver(Object):
def on_foo(self):
assert False, 'should not be reached'
def on_bar(self, event, extra):
assert False, 'should not be reached'
def on_baz(self, event, extra=None, *, k):
assert False, 'should not be reached'
def on_qux(self, event, extra=None):
assert False, 'should not be reached'
framework = self.create_framework()
pub = MyNotifier(framework, "pub")
obs = MyObserver(framework, "obs")
with self.assertRaises(TypeError):
framework.observe(pub.foo, obs)
with self.assertRaises(TypeError):
framework.observe(pub.bar, obs)
with self.assertRaises(TypeError):
framework.observe(pub.baz, obs)
framework.observe(pub.qux, obs)
def test_on_pre_commit_emitted(self):
framework = self.create_framework()
class PreCommitObserver(Object):
_stored = StoredState()
def __init__(self, parent, key):
super().__init__(parent, key)
self.seen = []
self._stored.myinitdata = 40
def on_pre_commit(self, event):
self._stored.myinitdata = 41
self._stored.mydata = 42
self.seen.append(type(event))
def on_commit(self, event):
# Modifications made here will not be persisted.
self._stored.myinitdata = 42
self._stored.mydata = 43
self._stored.myotherdata = 43
self.seen.append(type(event))
obs = PreCommitObserver(framework, None)
framework.observe(framework.on.pre_commit, obs.on_pre_commit)
framework.commit()
self.assertEqual(obs._stored.myinitdata, 41)
self.assertEqual(obs._stored.mydata, 42)
self.assertTrue(obs.seen, [PreCommitEvent, CommitEvent])
framework.close()
other_framework = self.create_framework()
new_obs = PreCommitObserver(other_framework, None)
self.assertEqual(obs._stored.myinitdata, 41)
self.assertEqual(new_obs._stored.mydata, 42)
with self.assertRaises(AttributeError):
new_obs._stored.myotherdata
def test_defer_and_reemit(self):
framework = self.create_framework()
class MyEvent(EventBase):
pass
class MyNotifier1(Object):
a = EventSource(MyEvent)
b = EventSource(MyEvent)
class MyNotifier2(Object):
c = EventSource(MyEvent)
class MyObserver(Object):
def __init__(self, parent, key):
super().__init__(parent, key)
self.seen = []
self.done = {}
def on_any(self, event):
self.seen.append(event.handle.kind)
if not self.done.get(event.handle.kind):
event.defer()
pub1 = MyNotifier1(framework, "1")
pub2 = MyNotifier2(framework, "1")
obs1 = MyObserver(framework, "1")
obs2 = MyObserver(framework, "2")
framework.observe(pub1.a, obs1.on_any)
framework.observe(pub1.b, obs1.on_any)
framework.observe(pub1.a, obs2.on_any)
framework.observe(pub1.b, obs2.on_any)
framework.observe(pub2.c, obs2.on_any)
pub1.a.emit()
pub1.b.emit()
pub2.c.emit()
# Events remain stored because they were deferred.
ev_a_handle = Handle(pub1, "a", "1")
framework.load_snapshot(ev_a_handle)
ev_b_handle = Handle(pub1, "b", "2")
framework.load_snapshot(ev_b_handle)
ev_c_handle = Handle(pub2, "c", "3")
framework.load_snapshot(ev_c_handle)
# make sure the objects are gone before we reemit them
gc.collect()
framework.reemit()
obs1.done["a"] = True
obs2.done["b"] = True
framework.reemit()
framework.reemit()
obs1.done["b"] = True
obs2.done["a"] = True
framework.reemit()
obs2.done["c"] = True
framework.reemit()
framework.reemit()
framework.reemit()
self.assertEqual(" ".join(obs1.seen), "a b a b a b b b")
self.assertEqual(" ".join(obs2.seen), "a b c a b c a b c a c a c c")
# Now the event objects must all be gone from storage.
self.assertRaises(NoSnapshotError, framework.load_snapshot, ev_a_handle)
self.assertRaises(NoSnapshotError, framework.load_snapshot, ev_b_handle)
self.assertRaises(NoSnapshotError, framework.load_snapshot, ev_c_handle)
def test_custom_event_data(self):
framework = self.create_framework()
class MyEvent(EventBase):
def __init__(self, handle, n):
super().__init__(handle)
self.my_n = n
def snapshot(self):
return {"My N!": self.my_n}
def restore(self, snapshot):
super().restore(snapshot)
self.my_n = snapshot["My N!"] + 1
class MyNotifier(Object):
foo = EventSource(MyEvent)
class MyObserver(Object):
def __init__(self, parent, key):
super().__init__(parent, key)
self.seen = []
def on_foo(self, event):
self.seen.append("on_foo:{}={}".format(event.handle.kind, event.my_n))
event.defer()
pub = MyNotifier(framework, "1")
obs = MyObserver(framework, "1")
framework.observe(pub.foo, obs)
pub.foo.emit(1)
framework.reemit()
# Two things being checked here:
#
# 1. There's a restore roundtrip before the event is first observed.
# That means the data is safe before it's ever seen, and the
# roundtrip logic is tested under normal circumstances.
#
# 2. The renotification restores from the pristine event, not
# from the one modified during the first restore (otherwise
# we'd get a foo=3).
#
self.assertEqual(obs.seen, ["on_foo:foo=2", "on_foo:foo=2"])
def test_weak_observer(self):
framework = self.create_framework()
observed_events = []
class MyEvent(EventBase):
pass
class MyEvents(ObjectEvents):
foo = EventSource(MyEvent)
class MyNotifier(Object):
on = MyEvents()
class MyObserver(Object):
def on_foo(self, event):
observed_events.append("foo")
pub = MyNotifier(framework, "1")
obs = MyObserver(framework, "2")
framework.observe(pub.on.foo, obs)
pub.on.foo.emit()
self.assertEqual(observed_events, ["foo"])
# Now delete the observer, and note that when we emit the event, it
# doesn't update the local slice again
del obs
gc.collect()
pub.on.foo.emit()
self.assertEqual(observed_events, ["foo"])
def test_forget_and_multiple_objects(self):
framework = self.create_framework()
class MyObject(Object):
pass
o1 = MyObject(framework, "path")
# Creating a second object at the same path should fail with RuntimeError
with self.assertRaises(RuntimeError):
o2 = MyObject(framework, "path")
# Unless we _forget the object first
framework._forget(o1)
o2 = MyObject(framework, "path")
self.assertEqual(o1.handle.path, o2.handle.path)
# Deleting the tracked object should also work
del o2
gc.collect()
o3 = MyObject(framework, "path")
self.assertEqual(o1.handle.path, o3.handle.path)
framework.close()
# Or using a second framework
framework_copy = self.create_framework()
o_copy = MyObject(framework_copy, "path")
self.assertEqual(o1.handle.path, o_copy.handle.path)
def test_forget_and_multiple_objects_with_load_snapshot(self):
framework = self.create_framework()
class MyObject(Object):
def __init__(self, parent, name):
super().__init__(parent, name)
self.value = name
def snapshot(self):
return self.value
def restore(self, value):
self.value = value
framework.register_type(MyObject, None, MyObject.handle_kind)
o1 = MyObject(framework, "path")
framework.save_snapshot(o1)
framework.commit()
o_handle = o1.handle
del o1
gc.collect()
o2 = framework.load_snapshot(o_handle)
# Trying to load_snapshot a second object at the same path should fail with RuntimeError
with self.assertRaises(RuntimeError):
framework.load_snapshot(o_handle)
# Unless we _forget the object first
framework._forget(o2)
o3 = framework.load_snapshot(o_handle)
self.assertEqual(o2.value, o3.value)
# A loaded object also prevents direct creation of an object
with self.assertRaises(RuntimeError):
MyObject(framework, "path")
framework.close()
# But we can create an object, or load a snapshot in a copy of the framework
framework_copy1 = self.create_framework()
o_copy1 = MyObject(framework_copy1, "path")
self.assertEqual(o_copy1.value, "path")
framework_copy1.close()
framework_copy2 = self.create_framework()
framework_copy2.register_type(MyObject, None, MyObject.handle_kind)
o_copy2 = framework_copy2.load_snapshot(o_handle)
self.assertEqual(o_copy2.value, "path")
def test_events_base(self):
framework = self.create_framework()
class MyEvent(EventBase):
pass
class MyEvents(ObjectEvents):
foo = EventSource(MyEvent)
bar = EventSource(MyEvent)
class MyNotifier(Object):
on = MyEvents()
class MyObserver(Object):
def __init__(self, parent, key):
super().__init__(parent, key)
self.seen = []
def on_foo(self, event):
self.seen.append("on_foo:{}".format(event.handle.kind))
event.defer()
def on_bar(self, event):
self.seen.append("on_bar:{}".format(event.handle.kind))
pub = MyNotifier(framework, "1")
obs = MyObserver(framework, "1")
# Confirm that temporary persistence of BoundEvents doesn't cause errors,
# and that events can be observed.
for bound_event in [pub.on.foo, pub.on.bar]:
framework.observe(bound_event, obs)
# Confirm that events can be emitted and seen.
pub.on.foo.emit()
self.assertEqual(obs.seen, ["on_foo:foo"])
def test_conflicting_event_attributes(self):
class MyEvent(EventBase):
pass
event = EventSource(MyEvent)
class MyEvents(ObjectEvents):
foo = event
with self.assertRaises(RuntimeError) as cm:
class OtherEvents(ObjectEvents):
foo = event
self.assertEqual(
str(cm.exception),
"EventSource(MyEvent) reused as MyEvents.foo and OtherEvents.foo")
with self.assertRaises(RuntimeError) as cm:
class MyNotifier(Object):
on = MyEvents()
bar = event
self.assertEqual(
str(cm.exception),
"EventSource(MyEvent) reused as MyEvents.foo and MyNotifier.bar")
def test_reemit_ignores_unknown_event_type(self):
# The event type may have been gone for good, and nobody cares,
# so this shouldn't be an error scenario.
framework = self.create_framework()
class MyEvent(EventBase):
pass
class MyNotifier(Object):
foo = EventSource(MyEvent)
class MyObserver(Object):
def __init__(self, parent, key):
super().__init__(parent, key)
self.seen = []
def on_foo(self, event):
self.seen.append(event.handle)
event.defer()
pub = MyNotifier(framework, "1")
obs = MyObserver(framework, "1")
framework.observe(pub.foo, obs)
pub.foo.emit()
event_handle = obs.seen[0]
self.assertEqual(event_handle.kind, "foo")
framework.commit()
framework.close()
framework_copy = self.create_framework()
# No errors on missing event types here.
framework_copy.reemit()
# Register the type and check that the event is gone from storage.
framework_copy.register_type(MyEvent, event_handle.parent, event_handle.kind)
self.assertRaises(NoSnapshotError, framework_copy.load_snapshot, event_handle)
def test_auto_register_event_types(self):
framework = self.create_framework()
class MyFoo(EventBase):
pass
class MyBar(EventBase):
pass
class MyEvents(ObjectEvents):
foo = EventSource(MyFoo)
class MyNotifier(Object):
on = MyEvents()
bar = EventSource(MyBar)
class MyObserver(Object):
def __init__(self, parent, key):
super().__init__(parent, key)
self.seen = []
def on_foo(self, event):
self.seen.append("on_foo:{}:{}".format(type(event).__name__, event.handle.kind))
event.defer()
def on_bar(self, event):
self.seen.append("on_bar:{}:{}".format(type(event).__name__, event.handle.kind))
event.defer()
pub = MyNotifier(framework, "1")
obs = MyObserver(framework, "1")
pub.on.foo.emit()
pub.bar.emit()
framework.observe(pub.on.foo, obs)
framework.observe(pub.bar, obs)
pub.on.foo.emit()
pub.bar.emit()
self.assertEqual(obs.seen, ["on_foo:MyFoo:foo", "on_bar:MyBar:bar"])
def test_dynamic_event_types(self):
framework = self.create_framework()
class MyEventsA(ObjectEvents):
handle_kind = 'on_a'
class MyEventsB(ObjectEvents):
handle_kind = 'on_b'
class MyNotifier(Object):
on_a = MyEventsA()
on_b = MyEventsB()
class MyObserver(Object):
def __init__(self, parent, key):
super().__init__(parent, key)
self.seen = []
def on_foo(self, event):
self.seen.append("on_foo:{}:{}".format(type(event).__name__, event.handle.kind))
event.defer()
def on_bar(self, event):
self.seen.append("on_bar:{}:{}".format(type(event).__name__, event.handle.kind))
event.defer()
pub = MyNotifier(framework, "1")
obs = MyObserver(framework, "1")
class MyFoo(EventBase):
pass
class MyBar(EventBase):
pass
class DeadBeefEvent(EventBase):
pass
class NoneEvent(EventBase):
pass
pub.on_a.define_event("foo", MyFoo)
pub.on_b.define_event("bar", MyBar)
framework.observe(pub.on_a.foo, obs)
framework.observe(pub.on_b.bar, obs)
pub.on_a.foo.emit()
pub.on_b.bar.emit()
self.assertEqual(obs.seen, ["on_foo:MyFoo:foo", "on_bar:MyBar:bar"])
# Definitions remained local to the specific type.
self.assertRaises(AttributeError, lambda: pub.on_a.bar)
self.assertRaises(AttributeError, lambda: pub.on_b.foo)
# Try to use an event name which is not a valid python identifier.
with self.assertRaises(RuntimeError):
pub.on_a.define_event("dead-beef", DeadBeefEvent)
# Try to use a python keyword for an event name.
with self.assertRaises(RuntimeError):
pub.on_a.define_event("None", NoneEvent)
# Try to override an existing attribute.
with self.assertRaises(RuntimeError):
pub.on_a.define_event("foo", MyFoo)
def test_event_key_roundtrip(self):
class MyEvent(EventBase):
def __init__(self, handle, value):
super().__init__(handle)
self.value = value
def snapshot(self):
return self.value
def restore(self, value):
self.value = value
class MyNotifier(Object):
foo = EventSource(MyEvent)
class MyObserver(Object):
has_deferred = False
def __init__(self, parent, key):
super().__init__(parent, key)
self.seen = []
def on_foo(self, event):
self.seen.append((event.handle.key, event.value))
# Only defer the first event and once.
if not MyObserver.has_deferred:
event.defer()
MyObserver.has_deferred = True
framework1 = self.create_framework()
pub1 = MyNotifier(framework1, "pub")
obs1 = MyObserver(framework1, "obs")
framework1.observe(pub1.foo, obs1)
pub1.foo.emit('first')
self.assertEqual(obs1.seen, [('1', 'first')])
framework1.commit()
framework1.close()
del framework1
framework2 = self.create_framework()
pub2 = MyNotifier(framework2, "pub")
obs2 = MyObserver(framework2, "obs")
framework2.observe(pub2.foo, obs2)
pub2.foo.emit('second')
framework2.reemit()
# First observer didn't get updated, since framework it was bound to is gone.
self.assertEqual(obs1.seen, [('1', 'first')])
# Second observer saw the new event plus the reemit of the first event.
# (The event key goes up by 2 due to the pre-commit and commit events.)
self.assertEqual(obs2.seen, [('4', 'second'), ('1', 'first')])
def test_helper_properties(self):
framework = self.create_framework()
framework.model = 'test-model'
framework.meta = 'test-meta'
my_obj = Object(framework, 'my_obj')
self.assertEqual(my_obj.model, framework.model)
def test_ban_concurrent_frameworks(self):
f = self.create_framework()
with self.assertRaises(Exception) as cm:
self.create_framework()
self.assertIn('database is locked', str(cm.exception))
f.close()
def test_snapshot_saving_restricted_to_simple_types(self):
# this can not be saved, as it has not simple types!
to_be_saved = {"bar": TestFramework}
class FooEvent(EventBase):
def snapshot(self):
return to_be_saved
handle = Handle(None, "a_foo", "some_key")
event = FooEvent(handle)
framework = self.create_framework()
framework.register_type(FooEvent, None, handle.kind)
with self.assertRaises(ValueError) as cm:
framework.save_snapshot(event)
expected = (
"unable to save the data for FooEvent, it must contain only simple types: "
"{'bar': <class 'test.test_framework.TestFramework'>}")
self.assertEqual(str(cm.exception), expected)
class TestStoredState(unittest.TestCase):
def setUp(self):
self.tmpdir = Path(tempfile.mkdtemp())
self.addCleanup(shutil.rmtree, str(self.tmpdir))
def create_framework(self, cls=Framework):
framework = cls(self.tmpdir / "framework.data", self.tmpdir, None, None)
self.addCleanup(framework.close)
return framework
def test_basic_state_storage(self):
class SomeObject(Object):
_stored = StoredState()
self._stored_state_tests(SomeObject)
def test_straight_subclass(self):
class SomeObject(Object):
_stored = StoredState()
class Sub(SomeObject):
pass
self._stored_state_tests(Sub)
def test_straight_sub_subclass(self):
class SomeObject(Object):
_stored = StoredState()
class Sub(SomeObject):
pass
class SubSub(SomeObject):
pass
self._stored_state_tests(SubSub)
def test_two_subclasses(self):
class SomeObject(Object):
_stored = StoredState()
class SubA(SomeObject):
pass
class SubB(SomeObject):
pass
self._stored_state_tests(SubA)
self._stored_state_tests(SubB)
def test_the_crazy_thing(self):
class NoState(Object):
pass
class StatedObject(NoState):
_stored = StoredState()
class Sibling(NoState):
pass
class FinalChild(StatedObject, Sibling):
pass
self._stored_state_tests(FinalChild)
def _stored_state_tests(self, cls):
framework = self.create_framework()
obj = cls(framework, "1")
try:
obj._stored.foo
except AttributeError as e:
self.assertEqual(str(e), "attribute 'foo' is not stored")
else:
self.fail("AttributeError not raised")
try:
obj._stored.on = "nonono"
except AttributeError as e:
self.assertEqual(str(e), "attribute 'on' is reserved and cannot be set")
else:
self.fail("AttributeError not raised")
obj._stored.foo = 41
obj._stored.foo = 42
obj._stored.bar = "s"
obj._stored.baz = 4.2
obj._stored.bing = True
self.assertEqual(obj._stored.foo, 42)
framework.commit()
# This won't be committed, and should not be seen.
obj._stored.foo = 43
framework.close()
# Since this has the same absolute object handle, it will get its state back.
framework_copy = self.create_framework()
obj_copy = cls(framework_copy, "1")
self.assertEqual(obj_copy._stored.foo, 42)
self.assertEqual(obj_copy._stored.bar, "s")
self.assertEqual(obj_copy._stored.baz, 4.2)
self.assertEqual(obj_copy._stored.bing, True)
framework_copy.close()
def test_two_subclasses_no_conflicts(self):
class Base(Object):
_stored = StoredState()
class SubA(Base):
pass
class SubB(Base):
pass
framework = self.create_framework()
a = SubA(framework, None)
b = SubB(framework, None)
z = Base(framework, None)
a._stored.foo = 42
b._stored.foo = "hello"
z._stored.foo = {1}
framework.commit()
framework.close()
framework2 = self.create_framework()
a2 = SubA(framework2, None)
b2 = SubB(framework2, None)
z2 = Base(framework2, None)
self.assertEqual(a2._stored.foo, 42)
self.assertEqual(b2._stored.foo, "hello")
self.assertEqual(z2._stored.foo, {1})
def test_two_names_one_state(self):
class Mine(Object):
_stored = StoredState()
_stored2 = _stored
framework = self.create_framework()
obj = Mine(framework, None)
with self.assertRaises(RuntimeError):
obj._stored.foo = 42
with self.assertRaises(RuntimeError):
obj._stored2.foo = 42
framework.close()
# make sure we're not changing the object on failure
self.assertNotIn("_stored", obj.__dict__)
self.assertNotIn("_stored2", obj.__dict__)
def test_same_name_two_classes(self):
class Base(Object):
pass
class A(Base):
_stored = StoredState()
class B(Base):
_stored = A._stored
framework = self.create_framework()
a = A(framework, None)
b = B(framework, None)
# NOTE it's the second one that actually triggers the
# exception, but that's an implementation detail
a._stored.foo = 42
with self.assertRaises(RuntimeError):
b._stored.foo = "xyzzy"
framework.close()
# make sure we're not changing the object on failure
self.assertNotIn("_stored", b.__dict__)
def test_mutable_types_invalid(self):
framework = self.create_framework()
class SomeObject(Object):
_stored = StoredState()
obj = SomeObject(framework, '1')
try:
class CustomObject:
pass
obj._stored.foo = CustomObject()
except AttributeError as e:
self.assertEqual(
str(e),
"attribute 'foo' cannot be a CustomObject: must be int/float/dict/list/etc")
else:
self.fail('AttributeError not raised')
framework.commit()
def test_mutable_types(self):
# Test and validation functions in a list of 2-tuples.
# Assignment and keywords like del are not supported in lambdas
# so functions are used instead.
test_operations = [(
lambda: {}, # Operand A.
None, # Operand B.
{}, # Expected result.
lambda a, b: None, # Operation to perform.
lambda res, expected_res: self.assertEqual(res, expected_res) # Validation to perform.
), (
lambda: {},
{'a': {}},
{'a': {}},
lambda a, b: a.update(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: {'a': {}},
{'b': 'c'},
{'a': {'b': 'c'}},
lambda a, b: a['a'].update(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: {'a': {'b': 'c'}},
{'d': 'e'},
{'a': {'b': 'c', 'd': 'e'}},
lambda a, b: a['a'].update(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: {'a': {'b': 'c', 'd': 'e'}},
'd',
{'a': {'b': 'c'}},
lambda a, b: a['a'].pop(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: {'s': set()},
'a',
{'s': {'a'}},
lambda a, b: a['s'].add(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: {'s': {'a'}},
'a',
{'s': set()},
lambda a, b: a['s'].discard(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: [],
None,
[],
lambda a, b: None,
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: [],
'a',
['a'],
lambda a, b: a.append(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: ['a'],
['c'],
['a', ['c']],
lambda a, b: a.append(b),
lambda res, expected_res: (
self.assertEqual(res, expected_res),
self.assertIsInstance(res[1], StoredList),
)
), (
lambda: ['a', ['c']],
'b',
['b', 'a', ['c']],
lambda a, b: a.insert(0, b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: ['b', 'a', ['c']],
['d'],
['b', ['d'], 'a', ['c']],
lambda a, b: a.insert(1, b),
lambda res, expected_res: (
self.assertEqual(res, expected_res),
self.assertIsInstance(res[1], StoredList)
),
), (
lambda: ['b', 'a', ['c']],
['d'],
['b', ['d'], ['c']],
# a[1] = b
lambda a, b: a.__setitem__(1, b),
lambda res, expected_res: (
self.assertEqual(res, expected_res),
self.assertIsInstance(res[1], StoredList)
),
), (
lambda: ['b', ['d'], 'a', ['c']],
0,
[['d'], 'a', ['c']],
lambda a, b: a.pop(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: [['d'], 'a', ['c']],
['d'],
['a', ['c']],
lambda a, b: a.remove(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: ['a', ['c']],
'd',
['a', ['c', 'd']],
lambda a, b: a[1].append(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: ['a', ['c', 'd']],
1,
['a', ['c']],
lambda a, b: a[1].pop(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: ['a', ['c']],
'd',
['a', ['c', 'd']],
lambda a, b: a[1].insert(1, b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: ['a', ['c', 'd']],
'd',
['a', ['c']],
lambda a, b: a[1].remove(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: set(),
None,
set(),
lambda a, b: None,
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: set(),
'a',
set(['a']),
lambda a, b: a.add(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: set(['a']),
'a',
set(),
lambda a, b: a.discard(b),
lambda res, expected_res: self.assertEqual(res, expected_res)
), (
lambda: set(),
{'a'},
set(),
# Nested sets are not allowed as sets themselves are not hashable.
lambda a, b: self.assertRaises(TypeError, a.add, b),
lambda res, expected_res: self.assertEqual(res, expected_res)
)]
class SomeObject(Object):
_stored = StoredState()
class WrappedFramework(Framework):
def __init__(self, data_path, charm_dir, meta, model):
super().__init__(data_path, charm_dir, meta, model)
self.snapshots = []
def save_snapshot(self, value):
if value.handle.path == 'SomeObject[1]/StoredStateData[_stored]':
self.snapshots.append((type(value), value.snapshot()))
return super().save_snapshot(value)
# Validate correctness of modification operations.
for get_a, b, expected_res, op, validate_op in test_operations:
framework = self.create_framework(cls=WrappedFramework)
obj = SomeObject(framework, '1')
obj._stored.a = get_a()
self.assertTrue(isinstance(obj._stored, BoundStoredState))
op(obj._stored.a, b)
validate_op(obj._stored.a, expected_res)
obj._stored.a = get_a()
framework.commit()
# We should see an update for initializing a
self.assertEqual(framework.snapshots, [
(StoredStateData, {'a': get_a()}),
])
del obj
gc.collect()
obj_copy1 = SomeObject(framework, '1')
self.assertEqual(obj_copy1._stored.a, get_a())
op(obj_copy1._stored.a, b)
validate_op(obj_copy1._stored.a, expected_res)
framework.commit()
framework.close()
framework_copy = self.create_framework(cls=WrappedFramework)
obj_copy2 = SomeObject(framework_copy, '1')
validate_op(obj_copy2._stored.a, expected_res)
# Commit saves the pre-commit and commit events, and the framework
# event counter, but shouldn't update the stored state of my object
framework.snapshots.clear()
framework_copy.commit()
self.assertEqual(framework_copy.snapshots, [])
framework_copy.close()
def test_comparison_operations(self):
test_operations = [(
{"1"}, # Operand A.
{"1", "2"}, # Operand B.
lambda a, b: a < b, # Operation to test.
True, # Result of op(A, B).
False, # Result of op(B, A).
), (
{"1"},
{"1", "2"},
lambda a, b: a > b,
False,
True
), (
# Empty set comparison.
set(),
set(),
lambda a, b: a == b,
True,
True
), (
{"a", "c"},
{"c", "a"},
lambda a, b: a == b,
True,
True
), (
dict(),
dict(),
lambda a, b: a == b,
True,
True
), (
{"1": "2"},
{"1": "2"},
lambda a, b: a == b,
True,
True
), (
{"1": "2"},
{"1": "3"},
lambda a, b: a == b,
False,
False
), (
[],
[],
lambda a, b: a == b,
True,
True
), (
[1, 2],
[1, 2],
lambda a, b: a == b,
True,
True
), (
[1, 2, 5, 6],
[1, 2, 5, 8, 10],
lambda a, b: a <= b,
True,
False
), (
[1, 2, 5, 6],
[1, 2, 5, 8, 10],
lambda a, b: a < b,
True,
False
), (
[1, 2, 5, 8],
[1, 2, 5, 6, 10],
lambda a, b: a > b,
True,
False
), (
[1, 2, 5, 8],
[1, 2, 5, 6, 10],
lambda a, b: a >= b,
True,
False
)]
class SomeObject(Object):
_stored = StoredState()
framework = self.create_framework()
for i, (a, b, op, op_ab, op_ba) in enumerate(test_operations):
obj = SomeObject(framework, str(i))
obj._stored.a = a
self.assertEqual(op(obj._stored.a, b), op_ab)
self.assertEqual(op(b, obj._stored.a), op_ba)
def test_set_operations(self):
test_operations = [(
{"1"}, # A set to test an operation against (other_set).
lambda a, b: a | b, # An operation to test.
{"1", "a", "b"}, # The expected result of operation(obj._stored.set, other_set).
{"1", "a", "b"} # The expected result of operation(other_set, obj._stored.set).
), (
{"a", "c"},
lambda a, b: a - b,
{"b"},
{"c"}
), (
{"a", "c"},
lambda a, b: a & b,
{"a"},
{"a"}
), (
{"a", "c", "d"},
lambda a, b: a ^ b,
{"b", "c", "d"},
{"b", "c", "d"}
), (
set(),
lambda a, b: set(a),
{"a", "b"},
set()
)]
class SomeObject(Object):
_stored = StoredState()
framework = self.create_framework()
# Validate that operations between StoredSet and built-in sets
# only result in built-in sets being returned.
# Make sure that commutativity is preserved and that the
# original sets are not changed or used as a result.
for i, (variable_operand, operation, ab_res, ba_res) in enumerate(test_operations):
obj = SomeObject(framework, str(i))
obj._stored.set = {"a", "b"}
for a, b, expected in [
(obj._stored.set, variable_operand, ab_res),
(variable_operand, obj._stored.set, ba_res)]:
old_a = set(a)
old_b = set(b)
result = operation(a, b)
self.assertEqual(result, expected)
# Common sanity checks
self.assertIsNot(obj._stored.set._under, result)
self.assertIsNot(result, a)
self.assertIsNot(result, b)
self.assertEqual(a, old_a)
self.assertEqual(b, old_b)
def test_set_default(self):
framework = self.create_framework()
class StatefulObject(Object):
_stored = StoredState()
parent = StatefulObject(framework, 'key')
parent._stored.set_default(foo=1)
self.assertEqual(parent._stored.foo, 1)
parent._stored.set_default(foo=2)
# foo was already set, so it doesn't get replaced
self.assertEqual(parent._stored.foo, 1)
parent._stored.set_default(foo=3, bar=4)
self.assertEqual(parent._stored.foo, 1)
self.assertEqual(parent._stored.bar, 4)
# reloading the state still leaves things at the default values
framework.commit()
del parent
parent = StatefulObject(framework, 'key')
parent._stored.set_default(foo=5, bar=6)
self.assertEqual(parent._stored.foo, 1)
self.assertEqual(parent._stored.bar, 4)
# TODO: jam 2020-01-30 is there a clean way to tell that
# parent._stored._data.dirty is False?
def create_model(testcase):
"""Create a Model object."""
unit_name = 'myapp/0'
patcher = patch.dict(os.environ, {'JUJU_UNIT_NAME': unit_name})
patcher.start()
testcase.addCleanup(patcher.stop)
backend = model.ModelBackend()
meta = charm.CharmMeta()
test_model = model.Model('myapp/0', meta, backend)
return test_model
def create_framework(testcase, model=None):
"""Create a Framework object."""
framework = Framework(":memory:", charm_dir='non-existant', meta=None, model=model)
testcase.addCleanup(framework.close)
return framework
class GenericObserver(Object):
"""Generic observer for the tests."""
def __init__(self, parent, key):
super().__init__(parent, key)
self.called = False
def callback_method(self, event):
"""Set the instance .called to True."""
self.called = True
@patch('sys.stderr', new_callable=io.StringIO)
class BreakpointTests(unittest.TestCase):
def test_ignored(self, fake_stderr):
# It doesn't do anything really unless proper environment is there.
with patch.dict(os.environ):
os.environ.pop('JUJU_DEBUG_AT', None)
framework = create_framework(self)
with patch('pdb.Pdb.set_trace') as mock:
framework.breakpoint()
self.assertEqual(mock.call_count, 0)
self.assertEqual(fake_stderr.getvalue(), "")
def test_pdb_properly_called(self, fake_stderr):
# The debugger needs to leave the user in the frame where the breakpoint is executed,
# which for the test is the frame we're calling it here in the test :).
with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'all'}):
framework = create_framework(self)
with patch('pdb.Pdb.set_trace') as mock:
this_frame = inspect.currentframe()
framework.breakpoint()
self.assertEqual(mock.call_count, 1)
self.assertEqual(mock.call_args, ((this_frame,), {}))
def test_welcome_message(self, fake_stderr):
# Check that an initial message is shown to the user when code is interrupted.
with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'all'}):
framework = create_framework(self)
with patch('pdb.Pdb.set_trace'):
framework.breakpoint()
self.assertEqual(fake_stderr.getvalue(), _BREAKPOINT_WELCOME_MESSAGE)
def test_welcome_message_not_multiple(self, fake_stderr):
# Check that an initial message is NOT shown twice if the breakpoint is exercised
# twice in the same run.
with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'all'}):
framework = create_framework(self)
with patch('pdb.Pdb.set_trace'):
framework.breakpoint()
self.assertEqual(fake_stderr.getvalue(), _BREAKPOINT_WELCOME_MESSAGE)
framework.breakpoint()
self.assertEqual(fake_stderr.getvalue(), _BREAKPOINT_WELCOME_MESSAGE)
def test_builtin_breakpoint_hooked(self, fake_stderr):
# Verify that the proper hook is set.
with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'all'}):
create_framework(self) # creating the framework setups the hook
with patch('pdb.Pdb.set_trace') as mock:
# Calling through sys, not breakpoint() directly, so we can run the
# tests with Py < 3.7.
sys.breakpointhook()
self.assertEqual(mock.call_count, 1)
def test_breakpoint_names(self, fake_stderr):
framework = create_framework(self)
# Name rules:
# - must start and end with lowercase alphanumeric characters
# - only contain lowercase alphanumeric characters, or the hyphen "-"
good_names = [
'foobar',
'foo-bar-baz',
'foo-------bar',
'foo123',
'778',
'77-xx',
'a-b',
'ab',
'x',
]
for name in good_names:
with self.subTest(name=name):
framework.breakpoint(name)
bad_names = [
'',
'.',
'-',
'...foo',
'foo.bar',
'bar--'
'FOO',
'FooBar',
'foo bar',
'foo_bar',
'/foobar',
'break-here-☚',
]
msg = 'breakpoint names must look like "foo" or "foo-bar"'
for name in bad_names:
with self.subTest(name=name):
with self.assertRaises(ValueError) as cm:
framework.breakpoint(name)
self.assertEqual(str(cm.exception), msg)
reserved_names = [
'all',
'hook',
]
msg = 'breakpoint names "all" and "hook" are reserved'
for name in reserved_names:
with self.subTest(name=name):
with self.assertRaises(ValueError) as cm:
framework.breakpoint(name)
self.assertEqual(str(cm.exception), msg)
not_really_names = [
123,
1.1,
False,
]
for name in not_really_names:
with self.subTest(name=name):
with self.assertRaises(TypeError) as cm:
framework.breakpoint(name)
self.assertEqual(str(cm.exception), 'breakpoint names must be strings')
def check_trace_set(self, envvar_value, breakpoint_name, call_count):
"""Helper to check the diverse combinations of situations."""
with patch.dict(os.environ, {'JUJU_DEBUG_AT': envvar_value}):
framework = create_framework(self)
with patch('pdb.Pdb.set_trace') as mock:
framework.breakpoint(breakpoint_name)
self.assertEqual(mock.call_count, call_count)
def test_unnamed_indicated_all(self, fake_stderr):
# If 'all' is indicated, unnamed breakpoints will always activate.
self.check_trace_set('all', None, 1)
def test_unnamed_indicated_hook(self, fake_stderr):
# Special value 'hook' was indicated, nothing to do with any call.
self.check_trace_set('hook', None, 0)
def test_named_indicated_specifically(self, fake_stderr):
# Some breakpoint was indicated, and the framework call used exactly that name.
self.check_trace_set('mybreak', 'mybreak', 1)
def test_named_indicated_somethingelse(self, fake_stderr):
# Some breakpoint was indicated, but the framework call was not with that name.
self.check_trace_set('some-breakpoint', None, 0)
def test_named_indicated_ingroup(self, fake_stderr):
# A multiple breakpoint was indicated, and the framework call used a name among those.
self.check_trace_set('some,mybreak,foobar', 'mybreak', 1)
def test_named_indicated_all(self, fake_stderr):
# The framework indicated 'all', which includes any named breakpoint set.
self.check_trace_set('all', 'mybreak', 1)
def test_named_indicated_hook(self, fake_stderr):
# The framework indicated the special value 'hook', nothing to do with any named call.
self.check_trace_set('hook', 'mybreak', 0)
class DebugHookTests(unittest.TestCase):
def test_envvar_parsing_missing(self):
with patch.dict(os.environ):
os.environ.pop('JUJU_DEBUG_AT', None)
framework = create_framework(self)
self.assertEqual(framework._juju_debug_at, ())
def test_envvar_parsing_empty(self):
with patch.dict(os.environ, {'JUJU_DEBUG_AT': ''}):
framework = create_framework(self)
self.assertEqual(framework._juju_debug_at, ())
def test_envvar_parsing_simple(self):
with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'hook'}):
framework = create_framework(self)
self.assertEqual(framework._juju_debug_at, ['hook'])
def test_envvar_parsing_multiple(self):
with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'foo,bar,all'}):
framework = create_framework(self)
self.assertEqual(framework._juju_debug_at, ['foo', 'bar', 'all'])
def test_basic_interruption_enabled(self):
framework = create_framework(self)
framework._juju_debug_at = ['hook']
publisher = charm.CharmEvents(framework, "1")
observer = GenericObserver(framework, "1")
framework.observe(publisher.install, observer.callback_method)
with patch('sys.stderr', new_callable=io.StringIO) as fake_stderr:
with patch('pdb.runcall') as mock:
publisher.install.emit()
# Check that the pdb module was used correctly and that the callback method was NOT
# called (as we intercepted the normal pdb behaviour! this is to check that the
# framework didn't call the callback directly)
self.assertEqual(mock.call_count, 1)
expected_callback, expected_event = mock.call_args[0]
self.assertEqual(expected_callback, observer.callback_method)
self.assertIsInstance(expected_event, EventBase)
self.assertFalse(observer.called)
# Verify proper message was given to the user.
self.assertEqual(fake_stderr.getvalue(), _BREAKPOINT_WELCOME_MESSAGE)
def test_actions_are_interrupted(self):
test_model = create_model(self)
framework = create_framework(self, model=test_model)
framework._juju_debug_at = ['hook']
class CustomEvents(ObjectEvents):
foobar_action = EventSource(charm.ActionEvent)
publisher = CustomEvents(framework, "1")
observer = GenericObserver(framework, "1")
framework.observe(publisher.foobar_action, observer.callback_method)
fake_script(self, 'action-get', "echo {}")
with patch('sys.stderr', new_callable=io.StringIO):
with patch('pdb.runcall') as mock:
with patch.dict(os.environ, {'JUJU_ACTION_NAME': 'foobar'}):
publisher.foobar_action.emit()
self.assertEqual(mock.call_count, 1)
self.assertFalse(observer.called)
def test_internal_events_not_interrupted(self):
class MyNotifier(Object):
"""Generic notifier for the tests."""
bar = EventSource(EventBase)
framework = create_framework(self)
framework._juju_debug_at = ['hook']
publisher = MyNotifier(framework, "1")
observer = GenericObserver(framework, "1")
framework.observe(publisher.bar, observer.callback_method)
with patch('pdb.runcall') as mock:
publisher.bar.emit()
self.assertEqual(mock.call_count, 0)
self.assertTrue(observer.called)
def test_envvar_mixed(self):
framework = create_framework(self)
framework._juju_debug_at = ['foo', 'hook', 'all', 'whatever']
publisher = charm.CharmEvents(framework, "1")
observer = GenericObserver(framework, "1")
framework.observe(publisher.install, observer.callback_method)
with patch('sys.stderr', new_callable=io.StringIO):
with patch('pdb.runcall') as mock:
publisher.install.emit()
self.assertEqual(mock.call_count, 1)
self.assertFalse(observer.called)
def test_no_registered_method(self):
framework = create_framework(self)
framework._juju_debug_at = ['hook']
publisher = charm.CharmEvents(framework, "1")
observer = GenericObserver(framework, "1")
with patch('pdb.runcall') as mock:
publisher.install.emit()
self.assertEqual(mock.call_count, 0)
self.assertFalse(observer.called)
def test_envvar_nohook(self):
framework = create_framework(self)
framework._juju_debug_at = ['something-else']
publisher = charm.CharmEvents(framework, "1")
observer = GenericObserver(framework, "1")
framework.observe(publisher.install, observer.callback_method)
with patch.dict(os.environ, {'JUJU_DEBUG_AT': 'something-else'}):
with patch('pdb.runcall') as mock:
publisher.install.emit()
self.assertEqual(mock.call_count, 0)
self.assertTrue(observer.called)
def test_envvar_missing(self):
framework = create_framework(self)
framework._juju_debug_at = ()
publisher = charm.CharmEvents(framework, "1")
observer = GenericObserver(framework, "1")
framework.observe(publisher.install, observer.callback_method)
with patch('pdb.runcall') as mock:
publisher.install.emit()
self.assertEqual(mock.call_count, 0)
self.assertTrue(observer.called)
def test_welcome_message_not_multiple(self):
framework = create_framework(self)
framework._juju_debug_at = ['hook']
publisher = charm.CharmEvents(framework, "1")
observer = GenericObserver(framework, "1")
framework.observe(publisher.install, observer.callback_method)
with patch('sys.stderr', new_callable=io.StringIO) as fake_stderr:
with patch('pdb.runcall') as mock:
publisher.install.emit()
self.assertEqual(fake_stderr.getvalue(), _BREAKPOINT_WELCOME_MESSAGE)
publisher.install.emit()
self.assertEqual(fake_stderr.getvalue(), _BREAKPOINT_WELCOME_MESSAGE)
self.assertEqual(mock.call_count, 2)
# Copyright 2019 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import subprocess
import shutil
import tempfile
import unittest
def fake_script(test_case, name, content):
if not hasattr(test_case, 'fake_script_path'):
fake_script_path = tempfile.mkdtemp('-fake_script')
os.environ['PATH'] = '{}:{}'.format(fake_script_path, os.environ["PATH"])
def cleanup():
shutil.rmtree(fake_script_path)
os.environ['PATH'] = os.environ['PATH'].replace(fake_script_path + ':', '')
test_case.addCleanup(cleanup)
test_case.fake_script_path = pathlib.Path(fake_script_path)
with (test_case.fake_script_path / name).open('wt') as f:
# Before executing the provided script, dump the provided arguments in calls.txt.
f.write('''#!/bin/bash
{ echo -n $(basename $0); printf ";%s" "$@"; echo; } >> $(dirname $0)/calls.txt
''' + content)
os.chmod(str(test_case.fake_script_path / name), 0o755)
def fake_script_calls(test_case, clear=False):
try:
with (test_case.fake_script_path / 'calls.txt').open('r+t') as f:
calls = [line.split(';') for line in f.read().splitlines()]
if clear:
f.truncate(0)
return calls
except FileNotFoundError:
return []
class FakeScriptTest(unittest.TestCase):
def test_fake_script_works(self):
fake_script(self, 'foo', 'echo foo runs')
fake_script(self, 'bar', 'echo bar runs')
output = subprocess.getoutput('foo a "b c "; bar "d e" f')
self.assertEqual(output, 'foo runs\nbar runs')
self.assertEqual(fake_script_calls(self), [
['foo', 'a', 'b c '],
['bar', 'd e', 'f'],
])
def test_fake_script_clear(self):
fake_script(self, 'foo', 'echo foo runs')
output = subprocess.getoutput('foo a "b c"')
self.assertEqual(output, 'foo runs')
self.assertEqual(fake_script_calls(self, clear=True), [['foo', 'a', 'b c']])
fake_script(self, 'bar', 'echo bar runs')
output = subprocess.getoutput('bar "d e" f')
self.assertEqual(output, 'bar runs')
self.assertEqual(fake_script_calls(self, clear=True), [['bar', 'd e', 'f']])
self.assertEqual(fake_script_calls(self, clear=True), [])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment