Commit 6e969127 authored by aticig's avatar aticig Committed by Mark Beierl
Browse files

Adding updated VNFD package for robot test Basic25

parent 237506aa
##
# Copyright 2020 Canonical Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
##
touch:
description: "Touch a file on the VNF."
params:
filename:
description: "The name of the file to touch."
type: string
default: ""
required:
- filename
mkdir:
description: "Create a folder on the VNF."
params:
foldername:
description: "The name of the folder to create."
type: string
default: ""
required:
- foldername
# Standard OSM functions
start:
description: "Stop the service on the VNF."
stop:
description: "Stop the service on the VNF."
restart:
description: "Stop the service on the VNF."
reboot:
description: "Reboot the VNF virtual machine."
upgrade:
description: "Upgrade the software on the VNF."
# Required by charms.osm.sshproxy
run:
description: "Run an arbitrary command"
params:
command:
description: "The command to execute."
type: string
default: ""
required:
- command
generate-ssh-key:
description: "Generate a new SSH keypair for this unit. This will replace any existing previously generated keypair."
verify-ssh-credentials:
description: "Verify that this unit can authenticate with server specified by ssh-hostname and ssh-username."
get-ssh-public-key:
description: "Get the public SSH key for this unit."
##
# Copyright 2020 Canonical Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
##
options:
ssh-hostname:
type: string
default: ""
description: "The hostname or IP address of the machine to"
ssh-username:
type: string
default: ""
description: "The username to login as."
ssh-password:
type: string
default: ""
description: "The password used to authenticate."
ssh-public-key:
type: string
default: ""
description: "The public key of this unit."
ssh-key-type:
type: string
default: "rsa"
description: "The type of encryption to use for the SSH key."
ssh-key-bits:
type: int
default: 4096
description: "The number of bits to use for the SSH key."
#!/bin/sh
JUJU_DISPATCH_PATH="${JUJU_DISPATCH_PATH:-$0}" PYTHONPATH=lib:venv \
exec ./src/charm.py
#!/bin/sh
JUJU_DISPATCH_PATH="${JUJU_DISPATCH_PATH:-$0}" PYTHONPATH=lib:venv \
exec ./src/charm.py
#!/bin/sh
JUJU_DISPATCH_PATH="${JUJU_DISPATCH_PATH:-$0}" PYTHONPATH=lib:venv \
exec ./src/charm.py
#!/bin/sh
JUJU_DISPATCH_PATH="${JUJU_DISPATCH_PATH:-$0}" PYTHONPATH=lib:venv \
exec ./src/charm.py
##
# Copyright 2020 Canonical Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
##
import fnmatch
import os
import yaml
import subprocess
import sys
sys.path.append("lib")
import charmhelpers.fetch
ansible_hosts_path = "/etc/ansible/hosts"
def install_ansible_support(from_ppa=True, ppa_location="ppa:ansible/ansible"):
"""Installs the ansible package.
By default it is installed from the `PPA`_ linked from
the ansible `website`_ or from a ppa specified by a charm config..
.. _PPA: https://launchpad.net/~rquillo/+archive/ansible
.. _website: http://docs.ansible.com/intro_installation.html#latest-releases-via-apt-ubuntu
If from_ppa is empty, you must ensure that the package is available
from a configured repository.
"""
if from_ppa:
charmhelpers.fetch.add_source(ppa_location)
charmhelpers.fetch.apt_update(fatal=True)
charmhelpers.fetch.apt_install("ansible")
with open(ansible_hosts_path, "w+") as hosts_file:
hosts_file.write("localhost ansible_connection=local")
def create_hosts(hostname, username, password, hosts):
inventory_path = "/etc/ansible/hosts"
with open(inventory_path, "w") as f:
f.write("[{}]\n".format(hosts))
h1 = "host ansible_host={0} ansible_user={1} ansible_password={2}\n".format(
hostname, username, password
)
f.write(h1)
def create_ansible_cfg():
ansible_config_path = "/etc/ansible/ansible.cfg"
with open(ansible_config_path, "w") as f:
f.write("[defaults]\n")
f.write("host_key_checking = False\n")
# Function to find the playbook path
def find(pattern, path):
result = ""
for root, dirs, files in os.walk(path):
for name in files:
if fnmatch.fnmatch(name, pattern):
result = os.path.join(root, name)
return result
def execute_playbook(playbook_file, hostname, user, password, vars_dict=None):
playbook_path = find(playbook_file, "/var/lib/juju/agents/")
with open(playbook_path, "r") as f:
playbook_data = yaml.load(f)
hosts = "all"
if "hosts" in playbook_data[0].keys() and playbook_data[0]["hosts"]:
hosts = playbook_data[0]["hosts"]
create_ansible_cfg()
create_hosts(hostname, user, password, hosts)
call = "ansible-playbook {} ".format(playbook_path)
if vars_dict and isinstance(vars_dict, dict) and len(vars_dict) > 0:
call += "--extra-vars "
string_var = ""
for k,v in vars_dict.items():
string_var += "{}={} ".format(k, v)
string_var = string_var.strip()
call += '"{}"'.format(string_var)
call = call.strip()
result = subprocess.check_output(call, shell=True)
return result
# A prototype of a library to aid in the development and operation of
# OSM Network Service charms
import asyncio
import logging
import os
import os.path
import re
import subprocess
import sys
import time
import yaml
try:
import juju
except ImportError:
# Not all cloud images are created equal
if not os.path.exists("/usr/bin/python3") or not os.path.exists("/usr/bin/pip3"):
# Update the apt cache
subprocess.check_call(["apt-get", "update"])
# Install the Python3 package
subprocess.check_call(["apt-get", "install", "-y", "python3", "python3-pip"],)
# Install the libjuju build dependencies
subprocess.check_call(["apt-get", "install", "-y", "libffi-dev", "libssl-dev"],)
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "juju"],
)
from juju.controller import Controller
# Quiet the debug logging
logging.getLogger('websockets.protocol').setLevel(logging.INFO)
logging.getLogger('juju.client.connection').setLevel(logging.WARN)
logging.getLogger('juju.model').setLevel(logging.WARN)
logging.getLogger('juju.machine').setLevel(logging.WARN)
class NetworkService:
"""A lightweight interface to the Juju controller.
This NetworkService client is specifically designed to allow a higher-level
"NS" charm to interoperate with "VNF" charms, allowing for the execution of
Primitives across other charms within the same model.
"""
endpoint = None
user = 'admin'
secret = None
port = 17070
loop = None
client = None
model = None
cacert = None
def __init__(self, user, secret, endpoint=None):
self.user = user
self.secret = secret
if endpoint is None:
addresses = os.environ['JUJU_API_ADDRESSES']
for address in addresses.split(' '):
self.endpoint = address
else:
self.endpoint = endpoint
# Stash the name of the model
self.model = os.environ['JUJU_MODEL_NAME']
# Load the ca-cert from agent.conf
AGENT_PATH = os.path.dirname(os.environ['JUJU_CHARM_DIR'])
with open("{}/agent.conf".format(AGENT_PATH), "r") as f:
try:
y = yaml.safe_load(f)
self.cacert = y['cacert']
except yaml.YAMLError as exc:
print("Unable to find Juju ca-cert.")
raise exc
# Create our event loop
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
async def connect(self):
"""Connect to the Juju controller."""
controller = Controller()
print(
"Connecting to controller... ws://{}:{} as {}/{}".format(
self.endpoint,
self.port,
self.user,
self.secret[-4:].rjust(len(self.secret), "*"),
)
)
await controller.connect(
endpoint=self.endpoint,
username=self.user,
password=self.secret,
cacert=self.cacert,
)
return controller
def __del__(self):
self.logout()
async def disconnect(self):
"""Disconnect from the Juju controller."""
if self.client:
print("Disconnecting Juju controller")
await self.client.disconnect()
def login(self):
"""Login to the Juju controller."""
if not self.client:
# Connect to the Juju API server
self.client = self.loop.run_until_complete(self.connect())
return self.client
def logout(self):
"""Logout of the Juju controller."""
if self.loop:
print("Disconnecting from API")
self.loop.run_until_complete(self.disconnect())
def FormatApplicationName(self, *args):
"""
Generate a Juju-compatible Application name
:param args tuple: Positional arguments to be used to construct the
application name.
Limitations::
- Only accepts characters a-z and non-consequitive dashes (-)
- Application name should not exceed 50 characters
Examples::
FormatApplicationName("ping_pong_ns", "ping_vnf", "a")
"""
appname = ""
for c in "-".join(list(args)):
if c.isdigit():
c = chr(97 + int(c))
elif not c.isalpha():
c = "-"
appname += c
return re.sub('-+', '-', appname.lower())
def GetApplicationName(self, nsr_name, vnf_name, vnf_member_index):
"""Get the runtime application name of a VNF/VDU.
This will generate an application name matching the name of the deployed charm,
given the right parameters.
:param nsr_name str: The name of the running Network Service, as specified at instantiation.
:param vnf_name str: The name of the VNF or VDU
:param vnf_member_index: The vnf-member-index as specified in the descriptor
"""
application_name = self.FormatApplicationName(nsr_name, vnf_member_index, vnf_name)
# This matches the logic used by the LCM
application_name = application_name[0:48]
vca_index = int(vnf_member_index) - 1
application_name += '-' + chr(97 + vca_index // 26) + chr(97 + vca_index % 26)
return application_name
def ExecutePrimitiveGetOutput(self, application, primitive, params={}, timeout=600):
"""Execute a single primitive and return it's output.
This is a blocking method that will execute a single primitive and wait
for its completion before return it's output.
:param application str: The application name provided by `GetApplicationName`.
:param primitive str: The name of the primitive to execute.
:param params list: A list of parameters.
:param timeout int: A timeout, in seconds, to wait for the primitive to finish. Defaults to 600 seconds.
"""
uuid = self.ExecutePrimitive(application, primitive, params)
status = None
output = None
starttime = time.time()
while(time.time() < starttime + timeout):
status = self.GetPrimitiveStatus(uuid)
if status in ['completed', 'failed']:
break
time.sleep(10)
# When the primitive is done, get the output
if status in ['completed', 'failed']:
output = self.GetPrimitiveOutput(uuid)
return output
def ExecutePrimitive(self, application, primitive, params={}):
"""Execute a primitive.
This is a non-blocking method to execute a primitive. It will return
the UUID of the queued primitive execution, which you can use
for subsequent calls to `GetPrimitiveStatus` and `GetPrimitiveOutput`.
:param application string: The name of the application
:param primitive string: The name of the Primitive.
:param params list: A list of parameters.
:returns uuid string: The UUID of the executed Primitive
"""
uuid = None
if not self.client:
self.login()
model = self.loop.run_until_complete(
self.client.get_model(self.model)
)
# Get the application
if application in model.applications:
app = model.applications[application]
# Execute the primitive
unit = app.units[0]
if unit:
action = self.loop.run_until_complete(
unit.run_action(primitive, **params)
)
uuid = action.id
print("Executing action: {}".format(uuid))
self.loop.run_until_complete(
model.disconnect()
)
else:
# Invalid mapping: application not found. Raise exception
raise Exception("Application not found: {}".format(application))
return uuid
def GetPrimitiveStatus(self, uuid):
"""Get the status of a Primitive execution.
This will return one of the following strings:
- pending
- running
- completed
- failed
:param uuid string: The UUID of the executed Primitive.
:returns: The status of the executed Primitive
"""
status = None
if not self.client:
self.login()
model = self.loop.run_until_complete(
self.client.get_model(self.model)
)
status = self.loop.run_until_complete(
model.get_action_status(uuid)
)
self.loop.run_until_complete(
model.disconnect()
)
return status[uuid]
def GetPrimitiveOutput(self, uuid):
"""Get the output of a completed Primitive execution.
:param uuid string: The UUID of the executed Primitive.
:returns: The output of the execution, or None if it's still running.
"""
result = None
if not self.client:
self.login()
model = self.loop.run_until_complete(
self.client.get_model(self.model)
)
result = self.loop.run_until_complete(
model.get_action_output(uuid)
)
self.loop.run_until_complete(
model.disconnect()
)
return result
import socket
from ops.framework import Object, StoredState
class ProxyCluster(Object):
state = StoredState()
def __init__(self, charm, relation_name):
super().__init__(charm, relation_name)
self._relation_name = relation_name
self._relation = self.framework.model.get_relation(self._relation_name)
self.framework.observe(charm.on.ssh_keys_initialized, self.on_ssh_keys_initialized)
self.state.set_default(ssh_public_key=None)
self.state.set_default(ssh_private_key=None)
def on_ssh_keys_initialized(self, event):
if not self.framework.model.unit.is_leader():
raise RuntimeError("The initial unit of a cluster must also be a leader.")
self.state.ssh_public_key = event.ssh_public_key
self.state.ssh_private_key = event.ssh_private_key
if not self.is_joined:
event.defer()
return
self._relation.data[self.model.app][
"ssh_public_key"
] = self.state.ssh_public_key
self._relation.data[self.model.app][
"ssh_private_key"
] = self.state.ssh_private_key
@property
def is_joined(self):
return self._relation is not None
@property
def ssh_public_key(self):
if self.is_joined:
return self._relation.data[self.model.app].get("ssh_public_key")
@property
def ssh_private_key(self):
if self.is_joined:
return self._relation.data[self.model.app].get("ssh_private_key")
@property
def is_cluster_initialized(self):
return (
True
if self.is_joined
and self._relation.data[self.model.app].get("ssh_public_key")
and self._relation.data[self.model.app].get("ssh_private_key")
else False
)
"""Module to help with executing commands over SSH."""
##
# Copyright 2016 Canonical Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
##
# from charmhelpers.core import unitdata
# from charmhelpers.core.hookenv import log
import io
import ipaddress
import subprocess
import os
import socket
import shlex
import traceback
import sys
from subprocess import (
check_call,
Popen,
CalledProcessError,
PIPE,
)
from ops.charm import CharmBase, CharmEvents
from ops.framework import StoredState, EventBase, EventSource
from ops.main import main
from ops.model import (
ActiveStatus,
BlockedStatus,
MaintenanceStatus,
WaitingStatus,
ModelError,
)
import os
import subprocess
from .proxy_cluster import ProxyCluster
import logging
logger = logging.getLogger(__name__)
class SSHKeysInitialized(EventBase):
def __init__(self, handle, ssh_public_key, ssh_private_key):
super().__init__(handle)
self.ssh_public_key = ssh_public_key
self.ssh_private_key = ssh_private_key
def snapshot(self):
return {
"ssh_public_key": self.ssh_public_key,
"ssh_private_key": self.ssh_private_key,
}
def restore(self, snapshot):
self.ssh_public_key = snapshot["ssh_public_key"]
self.ssh_private_key = snapshot["ssh_private_key"]
class ProxyClusterEvents(CharmEvents):
ssh_keys_initialized = EventSource(SSHKeysInitialized)
class SSHProxyCharm(CharmBase):
state = StoredState()
on = ProxyClusterEvents()
def __init__(self, framework, key):
super().__init__(framework, key)
self.peers = ProxyCluster(self, "proxypeer")
# SSH Proxy actions (primitives)
self.framework.observe(self.on.generate_ssh_key_action, self.on_generate_ssh_key_action)
self.framework.observe(self.on.get_ssh_public_key_action, self.on_get_ssh_public_key_action)
self.framework.observe(self.on.run_action, self.on_run_action)
self.framework.observe(self.on.verify_ssh_credentials_action, self.on_verify_ssh_credentials_action)
self.framework.observe(self.on.proxypeer_relation_changed, self.on_proxypeer_relation_changed)
def get_ssh_proxy(self):
"""Get the SSHProxy instance"""
proxy = SSHProxy(
hostname=self.model.config["ssh-hostname"],
username=self.model.config["ssh-username"],
password=self.model.config["ssh-password"],
)
return proxy
def on_proxypeer_relation_changed(self, event):
if self.peers.is_cluster_initialized and not SSHProxy.has_ssh_key():
pubkey = self.peers.ssh_public_key
privkey = self.peers.ssh_private_key
SSHProxy.write_ssh_keys(public=pubkey, private=privkey)
self.verify_credentials()
else:
event.defer()
def on_config_changed(self, event):
"""Handle changes in configuration"""
self.verify_credentials()
def on_install(self, event):
SSHProxy.install()
def on_start(self, event):
"""Called when the charm is being installed"""
if not self.peers.is_joined:
event.defer()
return
unit = self.model.unit
if not SSHProxy.has_ssh_key():
unit.status = MaintenanceStatus("Generating SSH keys...")
pubkey = None
privkey = None
if self.model.unit.is_leader():
if self.peers.is_cluster_initialized:
SSHProxy.write_ssh_keys(
public=self.peers.ssh_public_key,
private=self.peers.ssh_private_key,
)
else:
SSHProxy.generate_ssh_key()
self.on.ssh_keys_initialized.emit(
SSHProxy.get_ssh_public_key(), SSHProxy.get_ssh_private_key()
)
self.verify_credentials()
def verify_credentials(self):
unit = self.model.unit
# Unit should go into a waiting state until verify_ssh_credentials is successful
unit.status = WaitingStatus("Waiting for SSH credentials")
proxy = self.get_ssh_proxy()
verified, _ = proxy.verify_credentials()
if verified:
unit.status = ActiveStatus()
else:
unit.status = BlockedStatus("Invalid SSH credentials.")
return verified
#####################
# SSH Proxy methods #
#####################
def on_generate_ssh_key_action(self, event):
"""Generate a new SSH keypair for this unit."""
if self.model.unit.is_leader():
if not SSHProxy.generate_ssh_key():
event.fail("Unable to generate ssh key")
else:
event.fail("Unit is not leader")
return
def on_get_ssh_public_key_action(self, event):
"""Get the SSH public key for this unit."""
if self.model.unit.is_leader():
pubkey = SSHProxy.get_ssh_public_key()
event.set_results({"pubkey": SSHProxy.get_ssh_public_key()})
else:
event.fail("Unit is not leader")
return
def on_run_action(self, event):
"""Run an arbitrary command on the remote host."""
if self.model.unit.is_leader():
cmd = event.params["command"]
proxy = self.get_ssh_proxy()
stdout, stderr = proxy.run(cmd)
event.set_results({"output": stdout})
if len(stderr):
event.fail(stderr)
else:
event.fail("Unit is not leader")
return
def on_verify_ssh_credentials_action(self, event):
"""Verify the SSH credentials for this unit."""
unit = self.model.unit
if unit.is_leader():
proxy = self.get_ssh_proxy()
verified, stderr = proxy.verify_credentials()
if verified:
event.set_results({"verified": True})
unit.status = ActiveStatus()
else:
event.set_results({"verified": False, "stderr": stderr})
event.fail("Not verified")
unit.status = BlockedStatus("Invalid SSH credentials.")
else:
event.fail("Unit is not leader")
return
class LeadershipError(ModelError):
def __init__(self):
super().__init__("not leader")
class SSHProxy:
private_key_path = "/root/.ssh/id_sshproxy"
public_key_path = "/root/.ssh/id_sshproxy.pub"
key_type = "rsa"
key_bits = 4096
def __init__(self, hostname: str, username: str, password: str = ""):
self.hostname = hostname
self.username = username
self.password = password
@staticmethod
def install():
check_call("apt update && apt install -y openssh-client sshpass", shell=True)
@staticmethod
def generate_ssh_key():
"""Generate a 4096-bit rsa keypair."""
if not os.path.exists(SSHProxy.private_key_path):
cmd = "ssh-keygen -t {} -b {} -N '' -f {}".format(
SSHProxy.key_type, SSHProxy.key_bits, SSHProxy.private_key_path,
)
try:
check_call(cmd, shell=True)
except CalledProcessError:
return False
return True
@staticmethod
def write_ssh_keys(public, private):
"""Write a 4096-bit rsa keypair."""
with open(SSHProxy.public_key_path, "w") as f:
f.write(public)
f.close()
with open(SSHProxy.private_key_path, "w") as f:
f.write(private)
f.close()
@staticmethod
def get_ssh_public_key():
publickey = ""
if os.path.exists(SSHProxy.private_key_path):
with open(SSHProxy.public_key_path, "r") as f:
publickey = f.read()
return publickey
@staticmethod
def get_ssh_private_key():
privatekey = ""
if os.path.exists(SSHProxy.private_key_path):
with open(SSHProxy.private_key_path, "r") as f:
privatekey = f.read()
return privatekey
@staticmethod
def has_ssh_key():
return True if os.path.exists(SSHProxy.private_key_path) else False
def run(self, cmd: str) -> (str, str):
"""Run a command remotely via SSH.
Note: The previous behavior was to run the command locally if SSH wasn't
configured, but that can lead to cases where execution succeeds when you'd
expect it not to.
"""
if isinstance(cmd, str):
cmd = shlex.split(cmd)
host = self._get_hostname()
user = self.username
passwd = self.password
key = self.private_key_path
# Make sure we have everything we need to connect
if host and user:
return self.ssh(cmd)
raise Exception("Invalid SSH credentials.")
def scp(self, source_file, destination_file):
"""Execute an scp command. Requires a fully qualified source and
destination.
:param str source_file: Path to the source file
:param str destination_file: Path to the destination file
:raises: :class:`CalledProcessError` if the command fails
"""
cmd = [
"sshpass",
"-p",
self.password,
"scp",
"-i",
os.path.expanduser(self.private_key_path),
"-o",
"StrictHostKeyChecking=no",
"-q",
"-B",
]
destination = "{}@{}:{}".format(self.username, self.hostname, destination_file)
cmd.extend([source_file, destination])
subprocess.run(cmd, check=True)
def ssh(self, command):
"""Run a command remotely via SSH.
:param list(str) command: The command to execute
:return: tuple: The stdout and stderr of the command execution
:raises: :class:`CalledProcessError` if the command fails
"""
destination = "{}@{}".format(self.username, self.hostname)
cmd = [
"sshpass",
"-p",
self.password,
"ssh",
"-i",
os.path.expanduser(self.private_key_path),
"-o",
"StrictHostKeyChecking=no",
"-q",
destination,
]
cmd.extend(command)
output = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return (output.stdout.decode("utf-8").strip(), output.stderr.decode("utf-8").strip())
def verify_credentials(self):
"""Verify the SSH credentials.
:return (bool, str): Verified, Stderr
"""
verified = False
try:
(stdout, stderr) = self.run("hostname")
verified = True
except CalledProcessError as e:
stderr = "Command failed: {} ({})".format(" ".join(e.cmd), str(e.output))
except (TimeoutError, socket.timeout):
stderr = "Timeout attempting to reach {}".format(self._get_hostname())
except Exception as error:
tb = traceback.format_exc()
stderr = "Unhandled exception: {}".format(tb)
return verified, stderr
###################
# Private methods #
###################
def _get_hostname(self):
"""Get the hostname for the ssh target.
HACK: This function was added to work around an issue where the
ssh-hostname was passed in the format of a.b.c.d;a.b.c.d, where the first
is the floating ip, and the second the non-floating ip, for an Openstack
instance.
"""
return self.hostname.split(";")[0]
# Copyright 2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Operator Framework."""
from .version import version as __version__ # noqa: F401 (imported but unused)
# Import here the bare minimum to break the circular import between modules
from . import charm # noqa: F401 (imported but unused)
# Copyright 2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from functools import total_ordering
@total_ordering
class JujuVersion:
PATTERN = r'''^
(?P<major>\d{1,9})\.(?P<minor>\d{1,9}) # <major> and <minor> numbers are always there
((?:\.|-(?P<tag>[a-z]+))(?P<patch>\d{1,9}))? # sometimes with .<patch> or -<tag><patch>
(\.(?P<build>\d{1,9}))?$ # and sometimes with a <build> number.
'''
def __init__(self, version):
m = re.match(self.PATTERN, version, re.VERBOSE)
if not m:
raise RuntimeError('"{}" is not a valid Juju version string'.format(version))
d = m.groupdict()
self.major = int(m.group('major'))
self.minor = int(m.group('minor'))
self.tag = d['tag'] or ''
self.patch = int(d['patch'] or 0)
self.build = int(d['build'] or 0)
def __repr__(self):
if self.tag:
s = '{}.{}-{}{}'.format(self.major, self.minor, self.tag, self.patch)
else:
s = '{}.{}.{}'.format(self.major, self.minor, self.patch)
if self.build > 0:
s += '.{}'.format(self.build)
return s
def __eq__(self, other):
if self is other:
return True
if isinstance(other, str):
other = type(self)(other)
elif not isinstance(other, JujuVersion):
raise RuntimeError('cannot compare Juju version "{}" with "{}"'.format(self, other))
return (
self.major == other.major
and self.minor == other.minor
and self.tag == other.tag
and self.build == other.build
and self.patch == other.patch)
def __lt__(self, other):
if self is other:
return False
if isinstance(other, str):
other = type(self)(other)
elif not isinstance(other, JujuVersion):
raise RuntimeError('cannot compare Juju version "{}" with "{}"'.format(self, other))
if self.major != other.major:
return self.major < other.major
elif self.minor != other.minor:
return self.minor < other.minor
elif self.tag != other.tag:
if not self.tag:
return False
elif not other.tag:
return True
return self.tag < other.tag
elif self.patch != other.patch:
return self.patch < other.patch
elif self.build != other.build:
return self.build < other.build
return False
@classmethod
def from_environ(cls) -> 'JujuVersion':
"""Build a JujuVersion from JUJU_VERSION."""
v = os.environ.get('JUJU_VERSION')
if not v:
raise RuntimeError('environ has no JUJU_VERSION')
return cls(v)
def has_app_data(self) -> bool:
"""Determine whether this juju version knows about app data."""
return (self.major, self.minor, self.patch) >= (2, 7, 0)
# Copyright 2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import re
from ast import literal_eval
from importlib.util import module_from_spec
from importlib.machinery import ModuleSpec
from pkgutil import get_importer
from types import ModuleType
_libraries = None
_libline_re = re.compile(r'''^LIB([A-Z]+)\s*=\s*([0-9]+|['"][a-zA-Z0-9_.\-@]+['"])''')
_libname_re = re.compile(r'''^[a-z][a-z0-9]+$''')
# Not perfect, but should do for now.
_libauthor_re = re.compile(r'''^[A-Za-z0-9_+.-]+@[a-z0-9_-]+(?:\.[a-z0-9_-]+)*\.[a-z]{2,3}$''')
def use(name: str, api: int, author: str) -> ModuleType:
"""Use a library from the ops libraries.
Args:
name: the name of the library requested.
api: the API version of the library.
author: the author of the library. If not given, requests the
one in the standard library.
Raises:
ImportError: if the library cannot be found.
TypeError: if the name, api, or author are the wrong type.
ValueError: if the name, api, or author are invalid.
"""
if not isinstance(name, str):
raise TypeError("invalid library name: {!r} (must be a str)".format(name))
if not isinstance(author, str):
raise TypeError("invalid library author: {!r} (must be a str)".format(author))
if not isinstance(api, int):
raise TypeError("invalid library API: {!r} (must be an int)".format(api))
if api < 0:
raise ValueError('invalid library api: {} (must be ≥0)'.format(api))
if not _libname_re.match(name):
raise ValueError("invalid library name: {!r} (chars and digits only)".format(name))
if not _libauthor_re.match(author):
raise ValueError("invalid library author email: {!r}".format(author))
if _libraries is None:
autoimport()
versions = _libraries.get((name, author), ())
for lib in versions:
if lib.api == api:
return lib.import_module()
others = ', '.join(str(lib.api) for lib in versions)
if others:
msg = 'cannot find "{}" from "{}" with API version {} (have {})'.format(
name, author, api, others)
else:
msg = 'cannot find library "{}" from "{}"'.format(name, author)
raise ImportError(msg, name=name)
def autoimport():
"""Find all libs in the path and enable use of them.
You only need to call this if you've installed a package or
otherwise changed sys.path in the current run, and need to see the
changes. Otherwise libraries are found on first call of `use`.
"""
global _libraries
_libraries = {}
for spec in _find_all_specs(sys.path):
lib = _parse_lib(spec)
if lib is None:
continue
versions = _libraries.setdefault((lib.name, lib.author), [])
versions.append(lib)
versions.sort(reverse=True)
def _find_all_specs(path):
for sys_dir in path:
if sys_dir == "":
sys_dir = "."
try:
top_dirs = os.listdir(sys_dir)
except OSError:
continue
for top_dir in top_dirs:
opslib = os.path.join(sys_dir, top_dir, 'opslib')
try:
lib_dirs = os.listdir(opslib)
except OSError:
continue
finder = get_importer(opslib)
if finder is None or not hasattr(finder, 'find_spec'):
continue
for lib_dir in lib_dirs:
spec = finder.find_spec(lib_dir)
if spec is None:
continue
if spec.loader is None:
# a namespace package; not supported
continue
yield spec
# only the first this many lines of a file are looked at for the LIB* constants
_MAX_LIB_LINES = 99
def _parse_lib(spec):
if spec.origin is None:
return None
_expected = {'NAME': str, 'AUTHOR': str, 'API': int, 'PATCH': int}
try:
with open(spec.origin, 'rt', encoding='utf-8') as f:
libinfo = {}
for n, line in enumerate(f):
if len(libinfo) == len(_expected):
break
if n > _MAX_LIB_LINES:
return None
m = _libline_re.match(line)
if m is None:
continue
key, value = m.groups()
if key in _expected:
value = literal_eval(value)
if not isinstance(value, _expected[key]):
return None
libinfo[key] = value
else:
if len(libinfo) != len(_expected):
return None
except Exception:
return None
return _Lib(spec, libinfo['NAME'], libinfo['AUTHOR'], libinfo['API'], libinfo['PATCH'])
class _Lib:
def __init__(self, spec: ModuleSpec, name: str, author: str, api: int, patch: int):
self.spec = spec
self.name = name
self.author = author
self.api = api
self.patch = patch
self._module = None
def __repr__(self):
return "<_Lib {0.name} by {0.author}, API {0.api}, patch {0.patch}>".format(self)
def import_module(self) -> ModuleType:
if self._module is None:
module = module_from_spec(self.spec)
self.spec.loader.exec_module(module)
self._module = module
return self._module
def __eq__(self, other):
if not isinstance(other, _Lib):
return NotImplemented
a = (self.name, self.author, self.api, self.patch)
b = (other.name, other.author, other.api, other.patch)
return a == b
def __lt__(self, other):
if not isinstance(other, _Lib):
return NotImplemented
a = (self.name, self.author, self.api, self.patch)
b = (other.name, other.author, other.api, other.patch)
return a < b
# Copyright 2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import logging
class JujuLogHandler(logging.Handler):
"""A handler for sending logs to Juju via juju-log."""
def __init__(self, model_backend, level=logging.DEBUG):
super().__init__(level)
self.model_backend = model_backend
def emit(self, record):
self.model_backend.juju_log(record.levelname, self.format(record))
def setup_root_logging(model_backend, debug=False):
"""Setup python logging to forward messages to juju-log.
By default, logging is set to DEBUG level, and messages will be filtered by Juju.
Charmers can also set their own default log level with::
logging.getLogger().setLevel(logging.INFO)
model_backend -- a ModelBackend to use for juju-log
debug -- if True, write logs to stderr as well as to juju-log.
"""
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(JujuLogHandler(model_backend))
if debug:
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
sys.excepthook = lambda etype, value, tb: logger.error(
"Uncaught exception while in charm code:", exc_info=(etype, value, tb))
# Copyright 2019-2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import os
import subprocess
import sys
import warnings
from pathlib import Path
import yaml
import ops.charm
import ops.framework
import ops.model
import ops.storage
from ops.log import setup_root_logging
CHARM_STATE_FILE = '.unit-state.db'
logger = logging.getLogger()
def _get_charm_dir():
charm_dir = os.environ.get("JUJU_CHARM_DIR")
if charm_dir is None:
# Assume $JUJU_CHARM_DIR/lib/op/main.py structure.
charm_dir = Path('{}/../../..'.format(__file__)).resolve()
else:
charm_dir = Path(charm_dir).resolve()
return charm_dir
def _create_event_link(charm, bound_event):
"""Create a symlink for a particular event.
charm -- A charm object.
bound_event -- An event for which to create a symlink.
"""
if issubclass(bound_event.event_type, ops.charm.HookEvent):
event_dir = charm.framework.charm_dir / 'hooks'
event_path = event_dir / bound_event.event_kind.replace('_', '-')
elif issubclass(bound_event.event_type, ops.charm.ActionEvent):
if not bound_event.event_kind.endswith("_action"):
raise RuntimeError(
'action event name {} needs _action suffix'.format(bound_event.event_kind))
event_dir = charm.framework.charm_dir / 'actions'
# The event_kind is suffixed with "_action" while the executable is not.
event_path = event_dir / bound_event.event_kind[:-len('_action')].replace('_', '-')
else:
raise RuntimeError(
'cannot create a symlink: unsupported event type {}'.format(bound_event.event_type))
event_dir.mkdir(exist_ok=True)
if not event_path.exists():
# CPython has different implementations for populating sys.argv[0] for Linux and Windows.
# For Windows it is always an absolute path (any symlinks are resolved)
# while for Linux it can be a relative path.
target_path = os.path.relpath(os.path.realpath(sys.argv[0]), str(event_dir))
# Ignore the non-symlink files or directories
# assuming the charm author knows what they are doing.
logger.debug(
'Creating a new relative symlink at %s pointing to %s',
event_path, target_path)
event_path.symlink_to(target_path)
def _setup_event_links(charm_dir, charm):
"""Set up links for supported events that originate from Juju.
Whether a charm can handle an event or not can be determined by
introspecting which events are defined on it.
Hooks or actions are created as symlinks to the charm code file
which is determined by inspecting symlinks provided by the charm
author at hooks/install or hooks/start.
charm_dir -- A root directory of the charm.
charm -- An instance of the Charm class.
"""
for bound_event in charm.on.events().values():
# Only events that originate from Juju need symlinks.
if issubclass(bound_event.event_type, (ops.charm.HookEvent, ops.charm.ActionEvent)):
_create_event_link(charm, bound_event)
def _emit_charm_event(charm, event_name):
"""Emits a charm event based on a Juju event name.
charm -- A charm instance to emit an event from.
event_name -- A Juju event name to emit on a charm.
"""
event_to_emit = None
try:
event_to_emit = getattr(charm.on, event_name)
except AttributeError:
logger.debug("Event %s not defined for %s.", event_name, charm)
# If the event is not supported by the charm implementation, do
# not error out or try to emit it. This is to support rollbacks.
if event_to_emit is not None:
args, kwargs = _get_event_args(charm, event_to_emit)
logger.debug('Emitting Juju event %s.', event_name)
event_to_emit.emit(*args, **kwargs)
def _get_event_args(charm, bound_event):
event_type = bound_event.event_type
model = charm.framework.model
if issubclass(event_type, ops.charm.RelationEvent):
relation_name = os.environ['JUJU_RELATION']
relation_id = int(os.environ['JUJU_RELATION_ID'].split(':')[-1])
relation = model.get_relation(relation_name, relation_id)
else:
relation = None
remote_app_name = os.environ.get('JUJU_REMOTE_APP', '')
remote_unit_name = os.environ.get('JUJU_REMOTE_UNIT', '')
if remote_app_name or remote_unit_name:
if not remote_app_name:
if '/' not in remote_unit_name:
raise RuntimeError('invalid remote unit name: {}'.format(remote_unit_name))
remote_app_name = remote_unit_name.split('/')[0]
args = [relation, model.get_app(remote_app_name)]
if remote_unit_name:
args.append(model.get_unit(remote_unit_name))
return args, {}
elif relation:
return [relation], {}
return [], {}
class _Dispatcher:
"""Encapsulate how to figure out what event Juju wants us to run.
Also knows how to run “legacy” hooks when Juju called us via a top-level
``dispatch`` binary.
Args:
charm_dir: the toplevel directory of the charm
Attributes:
event_name: the name of the event to run
is_dispatch_aware: are we running under a Juju that knows about the
dispatch binary?
"""
def __init__(self, charm_dir: Path):
self._charm_dir = charm_dir
self._exec_path = Path(sys.argv[0])
if 'JUJU_DISPATCH_PATH' in os.environ and (charm_dir / 'dispatch').exists():
self._init_dispatch()
else:
self._init_legacy()
def ensure_event_links(self, charm):
"""Make sure necessary symlinks are present on disk"""
if self.is_dispatch_aware:
# links aren't needed
return
# When a charm is force-upgraded and a unit is in an error state Juju
# does not run upgrade-charm and instead runs the failed hook followed
# by config-changed. Given the nature of force-upgrading the hook setup
# code is not triggered on config-changed.
#
# 'start' event is included as Juju does not fire the install event for
# K8s charms (see LP: #1854635).
if (self.event_name in ('install', 'start', 'upgrade_charm')
or self.event_name.endswith('_storage_attached')):
_setup_event_links(self._charm_dir, charm)
def run_any_legacy_hook(self):
"""Run any extant legacy hook.
If there is both a dispatch file and a legacy hook for the
current event, run the wanted legacy hook.
"""
if not self.is_dispatch_aware:
# we *are* the legacy hook
return
dispatch_path = self._charm_dir / self._dispatch_path
if not dispatch_path.exists():
logger.debug("Legacy %s does not exist.", self._dispatch_path)
return
# super strange that there isn't an is_executable
if not os.access(str(dispatch_path), os.X_OK):
logger.warning("Legacy %s exists but is not executable.", self._dispatch_path)
return
if dispatch_path.resolve() == self._exec_path.resolve():
logger.debug("Legacy %s is just a link to ourselves.", self._dispatch_path)
return
argv = sys.argv.copy()
argv[0] = str(dispatch_path)
logger.info("Running legacy %s.", self._dispatch_path)
try:
subprocess.run(argv, check=True)
except subprocess.CalledProcessError as e:
logger.warning(
"Legacy %s exited with status %d.",
self._dispatch_path, e.returncode)
sys.exit(e.returncode)
else:
logger.debug("Legacy %s exited with status 0.", self._dispatch_path)
def _set_name_from_path(self, path: Path):
"""Sets the name attribute to that which can be inferred from the given path."""
name = path.name.replace('-', '_')
if path.parent.name == 'actions':
name = '{}_action'.format(name)
self.event_name = name
def _init_legacy(self):
"""Set up the 'legacy' dispatcher.
The current Juju doesn't know about 'dispatch' and calls hooks
explicitly.
"""
self.is_dispatch_aware = False
self._set_name_from_path(self._exec_path)
def _init_dispatch(self):
"""Set up the new 'dispatch' dispatcher.
The current Juju will run 'dispatch' if it exists, and otherwise fall
back to the old behaviour.
JUJU_DISPATCH_PATH will be set to the wanted hook, e.g. hooks/install,
in both cases.
"""
self._dispatch_path = Path(os.environ['JUJU_DISPATCH_PATH'])
if 'OPERATOR_DISPATCH' in os.environ:
logger.debug("Charm called itself via %s.", self._dispatch_path)
sys.exit(0)
os.environ['OPERATOR_DISPATCH'] = '1'
self.is_dispatch_aware = True
self._set_name_from_path(self._dispatch_path)
def is_restricted_context(self):
""""Return True if we are running in a restricted Juju context.
When in a restricted context, most commands (relation-get, config-get,
state-get) are not available. As such, we change how we interact with
Juju.
"""
return self.event_name in ('collect_metrics',)
def main(charm_class, use_juju_for_storage=False):
"""Setup the charm and dispatch the observed event.
The event name is based on the way this executable was called (argv[0]).
"""
charm_dir = _get_charm_dir()
model_backend = ops.model._ModelBackend()
debug = ('JUJU_DEBUG' in os.environ)
setup_root_logging(model_backend, debug=debug)
logger.debug("Operator Framework %s up and running.", ops.__version__)
dispatcher = _Dispatcher(charm_dir)
dispatcher.run_any_legacy_hook()
metadata = (charm_dir / 'metadata.yaml').read_text()
actions_meta = charm_dir / 'actions.yaml'
if actions_meta.exists():
actions_metadata = actions_meta.read_text()
else:
actions_metadata = None
if not yaml.__with_libyaml__:
logger.debug('yaml does not have libyaml extensions, using slower pure Python yaml loader')
meta = ops.charm.CharmMeta.from_yaml(metadata, actions_metadata)
model = ops.model.Model(meta, model_backend)
# TODO: If Juju unit agent crashes after exit(0) from the charm code
# the framework will commit the snapshot but Juju will not commit its
# operation.
charm_state_path = charm_dir / CHARM_STATE_FILE
if use_juju_for_storage:
if dispatcher.is_restricted_context():
# TODO: jam 2020-06-30 This unconditionally avoids running a collect metrics event
# Though we eventually expect that juju will run collect-metrics in a
# non-restricted context. Once we can determine that we are running collect-metrics
# in a non-restricted context, we should fire the event as normal.
logger.debug('"%s" is not supported when using Juju for storage\n'
'see: https://github.com/canonical/operator/issues/348',
dispatcher.event_name)
# Note that we don't exit nonzero, because that would cause Juju to rerun the hook
return
store = ops.storage.JujuStorage()
else:
store = ops.storage.SQLiteStorage(charm_state_path)
framework = ops.framework.Framework(store, charm_dir, meta, model)
try:
sig = inspect.signature(charm_class)
try:
sig.bind(framework)
except TypeError:
msg = (
"the second argument, 'key', has been deprecated and will be "
"removed after the 0.7 release")
warnings.warn(msg, DeprecationWarning)
charm = charm_class(framework, None)
else:
charm = charm_class(framework)
dispatcher.ensure_event_links(charm)
# TODO: Remove the collect_metrics check below as soon as the relevant
# Juju changes are made.
#
# Skip reemission of deferred events for collect-metrics events because
# they do not have the full access to all hook tools.
if not dispatcher.is_restricted_context():
framework.reemit()
_emit_charm_event(charm, dispatcher.event_name)
framework.commit()
finally:
framework.close()
# Copyright 2019-2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
import pickle
import shutil
import subprocess
import sqlite3
import typing
import yaml
class SQLiteStorage:
DB_LOCK_TIMEOUT = timedelta(hours=1)
def __init__(self, filename):
# The isolation_level argument is set to None such that the implicit
# transaction management behavior of the sqlite3 module is disabled.
self._db = sqlite3.connect(str(filename),
isolation_level=None,
timeout=self.DB_LOCK_TIMEOUT.total_seconds())
self._setup()
def _setup(self):
# Make sure that the database is locked until the connection is closed,
# not until the transaction ends.
self._db.execute("PRAGMA locking_mode=EXCLUSIVE")
c = self._db.execute("BEGIN")
c.execute("SELECT count(name) FROM sqlite_master WHERE type='table' AND name='snapshot'")
if c.fetchone()[0] == 0:
# Keep in mind what might happen if the process dies somewhere below.
# The system must not be rendered permanently broken by that.
self._db.execute("CREATE TABLE snapshot (handle TEXT PRIMARY KEY, data BLOB)")
self._db.execute('''
CREATE TABLE notice (
sequence INTEGER PRIMARY KEY AUTOINCREMENT,
event_path TEXT,
observer_path TEXT,
method_name TEXT)
''')
self._db.commit()
def close(self):
self._db.close()
def commit(self):
self._db.commit()
# There's commit but no rollback. For abort to be supported, we'll need logic that
# can rollback decisions made by third-party code in terms of the internal state
# of objects that have been snapshotted, and hooks to let them know about it and
# take the needed actions to undo their logic until the last snapshot.
# This is doable but will increase significantly the chances for mistakes.
def save_snapshot(self, handle_path: str, snapshot_data: typing.Any) -> None:
"""Part of the Storage API, persist a snapshot data under the given handle.
Args:
handle_path: The string identifying the snapshot.
snapshot_data: The data to be persisted. (as returned by Object.snapshot()). This
might be a dict/tuple/int, but must only contain 'simple' python types.
"""
# Use pickle for serialization, so the value remains portable.
raw_data = pickle.dumps(snapshot_data)
self._db.execute("REPLACE INTO snapshot VALUES (?, ?)", (handle_path, raw_data))
def load_snapshot(self, handle_path: str) -> typing.Any:
"""Part of the Storage API, retrieve a snapshot that was previously saved.
Args:
handle_path: The string identifying the snapshot.
Raises:
NoSnapshotError: if there is no snapshot for the given handle_path.
"""
c = self._db.cursor()
c.execute("SELECT data FROM snapshot WHERE handle=?", (handle_path,))
row = c.fetchone()
if row:
return pickle.loads(row[0])
raise NoSnapshotError(handle_path)
def drop_snapshot(self, handle_path: str):
"""Part of the Storage API, remove a snapshot that was previously saved.
Dropping a snapshot that doesn't exist is treated as a no-op.
"""
self._db.execute("DELETE FROM snapshot WHERE handle=?", (handle_path,))
def list_snapshots(self) -> typing.Generator[str, None, None]:
"""Return the name of all snapshots that are currently saved."""
c = self._db.cursor()
c.execute("SELECT handle FROM snapshot")
while True:
rows = c.fetchmany()
if not rows:
break
for row in rows:
yield row[0]
def save_notice(self, event_path: str, observer_path: str, method_name: str) -> None:
"""Part of the Storage API, record an notice (event and observer)"""
self._db.execute('INSERT INTO notice VALUES (NULL, ?, ?, ?)',
(event_path, observer_path, method_name))
def drop_notice(self, event_path: str, observer_path: str, method_name: str) -> None:
"""Part of the Storage API, remove a notice that was previously recorded."""
self._db.execute('''
DELETE FROM notice
WHERE event_path=?
AND observer_path=?
AND method_name=?
''', (event_path, observer_path, method_name))
def notices(self, event_path: typing.Optional[str]) ->\
typing.Generator[typing.Tuple[str, str, str], None, None]:
"""Part of the Storage API, return all notices that begin with event_path.
Args:
event_path: If supplied, will only yield events that match event_path. If not
supplied (or None/'') will return all events.
Returns:
Iterable of (event_path, observer_path, method_name) tuples
"""
if event_path:
c = self._db.execute('''
SELECT event_path, observer_path, method_name
FROM notice
WHERE event_path=?
ORDER BY sequence
''', (event_path,))
else:
c = self._db.execute('''
SELECT event_path, observer_path, method_name
FROM notice
ORDER BY sequence
''')
while True:
rows = c.fetchmany()
if not rows:
break
for row in rows:
yield tuple(row)
class JujuStorage:
""""Storing the content tracked by the Framework in Juju.
This uses :class:`_JujuStorageBackend` to interact with state-get/state-set
as the way to store state for the framework and for components.
"""
NOTICE_KEY = "#notices#"
def __init__(self, backend: '_JujuStorageBackend' = None):
self._backend = backend
if backend is None:
self._backend = _JujuStorageBackend()
def close(self):
return
def commit(self):
return
def save_snapshot(self, handle_path: str, snapshot_data: typing.Any) -> None:
self._backend.set(handle_path, snapshot_data)
def load_snapshot(self, handle_path):
try:
content = self._backend.get(handle_path)
except KeyError:
raise NoSnapshotError(handle_path)
return content
def drop_snapshot(self, handle_path):
self._backend.delete(handle_path)
def save_notice(self, event_path: str, observer_path: str, method_name: str):
notice_list = self._load_notice_list()
notice_list.append([event_path, observer_path, method_name])
self._save_notice_list(notice_list)
def drop_notice(self, event_path: str, observer_path: str, method_name: str):
notice_list = self._load_notice_list()
notice_list.remove([event_path, observer_path, method_name])
self._save_notice_list(notice_list)
def notices(self, event_path: str):
notice_list = self._load_notice_list()
for row in notice_list:
if row[0] != event_path:
continue
yield tuple(row)
def _load_notice_list(self) -> typing.List[typing.Tuple[str]]:
try:
notice_list = self._backend.get(self.NOTICE_KEY)
except KeyError:
return []
if notice_list is None:
return []
return notice_list
def _save_notice_list(self, notices: typing.List[typing.Tuple[str]]) -> None:
self._backend.set(self.NOTICE_KEY, notices)
class _SimpleLoader(getattr(yaml, 'CSafeLoader', yaml.SafeLoader)):
"""Handle a couple basic python types.
yaml.SafeLoader can handle all the basic int/float/dict/set/etc that we want. The only one
that it *doesn't* handle is tuples. We don't want to support arbitrary types, so we just
subclass SafeLoader and add tuples back in.
"""
# Taken from the example at:
# https://stackoverflow.com/questions/9169025/how-can-i-add-a-python-tuple-to-a-yaml-file-using-pyyaml
construct_python_tuple = yaml.Loader.construct_python_tuple
_SimpleLoader.add_constructor(
u'tag:yaml.org,2002:python/tuple',
_SimpleLoader.construct_python_tuple)
class _SimpleDumper(getattr(yaml, 'CSafeDumper', yaml.SafeDumper)):
"""Add types supported by 'marshal'
YAML can support arbitrary types, but that is generally considered unsafe (like pickle). So
we want to only support dumping out types that are safe to load.
"""
_SimpleDumper.represent_tuple = yaml.Dumper.represent_tuple
_SimpleDumper.add_representer(tuple, _SimpleDumper.represent_tuple)
class _JujuStorageBackend:
"""Implements the interface from the Operator framework to Juju's state-get/set/etc."""
@staticmethod
def is_available() -> bool:
"""Check if Juju state storage is available.
This checks if there is a 'state-get' executable in PATH.
"""
p = shutil.which('state-get')
return p is not None
def set(self, key: str, value: typing.Any) -> None:
"""Set a key to a given value.
Args:
key: The string key that will be used to find the value later
value: Arbitrary content that will be returned by get().
Raises:
CalledProcessError: if 'state-set' returns an error code.
"""
# default_flow_style=None means that it can use Block for
# complex types (types that have nested types) but use flow
# for simple types (like an array). Not all versions of PyYAML
# have the same default style.
encoded_value = yaml.dump(value, Dumper=_SimpleDumper, default_flow_style=None)
content = yaml.dump(
{key: encoded_value}, encoding='utf-8', default_style='|',
default_flow_style=False,
Dumper=_SimpleDumper)
subprocess.run(["state-set", "--file", "-"], input=content, check=True)
def get(self, key: str) -> typing.Any:
"""Get the bytes value associated with a given key.
Args:
key: The string key that will be used to find the value
Raises:
CalledProcessError: if 'state-get' returns an error code.
"""
# We don't capture stderr here so it can end up in debug logs.
p = subprocess.run(
["state-get", key],
stdout=subprocess.PIPE,
check=True,
)
if p.stdout == b'' or p.stdout == b'\n':
raise KeyError(key)
return yaml.load(p.stdout, Loader=_SimpleLoader)
def delete(self, key: str) -> None:
"""Remove a key from being tracked.
Args:
key: The key to stop storing
Raises:
CalledProcessError: if 'state-delete' returns an error code.
"""
subprocess.run(["state-delete", key], check=True)
class NoSnapshotError(Exception):
def __init__(self, handle_path):
self.handle_path = handle_path
def __str__(self):
return 'no snapshot data found for {} object'.format(self.handle_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment