From fbd25b458d70f0ca0743be60ff3d6ab21335707f Mon Sep 17 00:00:00 2001 From: Cynerva Date: Thu, 29 Jun 2017 13:48:40 -0500 Subject: [PATCH] Implement scp functionality (#149) * Redefine SCP interface * Implement Machine.scp_to and Machine.scp_from * Implement Unit.scp_to and Unit.scp_from * Fix Unit.machine to better handle missing machines * Use os.path.expanduser for scp key path * Use local-cloud address if public address is not found * Rename Machine.public_address to Machine.dns_name to match CLI * fix scp tests not waiting long enough --- juju/machine.py | 68 +++++++++++++++++++++++++++---- juju/unit.py | 36 ++++++++++++---- tests/integration/test_machine.py | 27 +++++++++++- tests/integration/test_unit.py | 32 +++++++++++++++ 4 files changed, 148 insertions(+), 15 deletions(-) diff --git a/juju/machine.py b/juju/machine.py index 18c333c..23b41c6 100644 --- a/juju/machine.py +++ b/juju/machine.py @@ -1,9 +1,12 @@ +import asyncio import logging +import os from dateutil.parser import parse as parse_date from . import model, utils from .client import client +from .errors import JujuError log = logging.getLogger(__name__) @@ -124,20 +127,56 @@ class Machine(model.ModelEntity): ) return await self.ann_facade.Set([ann]) - def scp( - self, source_path, user=None, destination_path=None, proxy=False, - scp_opts=None): + async def scp_to(self, source, destination, user='ubuntu', proxy=False, + scp_opts=''): """Transfer files to this machine. - :param str source_path: Path of file(s) to transfer + :param str source: Local path of file(s) to transfer + :param str destination: Remote destination of transferred files :param str user: Remote username - :param str destination_path: Destination of transferred files on - remote machine :param bool proxy: Proxy through the Juju API server :param str scp_opts: Additional options to the `scp` command + """ + if proxy: + raise NotImplementedError('proxy option is not implemented') + + address = self.dns_name + destination = '%s@%s:%s' % (user, address, destination) + await self._scp(source, destination, scp_opts) + async def scp_from(self, source, destination, user='ubuntu', proxy=False, + scp_opts=''): + """Transfer files from this machine. + + :param str source: Remote path of file(s) to transfer + :param str destination: Local destination of transferred files + :param str user: Remote username + :param bool proxy: Proxy through the Juju API server + :param str scp_opts: Additional options to the `scp` command """ - raise NotImplementedError() + if proxy: + raise NotImplementedError('proxy option is not implemented') + + address = self.dns_name + source = '%s@%s:%s' % (user, address, source) + await self._scp(source, destination, scp_opts) + + async def _scp(self, source, destination, scp_opts): + """ Execute an scp command. Requires a fully qualified source and + destination. + """ + cmd = [ + 'scp', + '-i', os.path.expanduser('~/.local/share/juju/ssh/juju_id_rsa'), + '-o', 'StrictHostKeyChecking=no', + source, destination + ] + cmd += scp_opts.split() + loop = self.model.loop + process = await asyncio.create_subprocess_exec(*cmd, loop=loop) + await process.wait() + if process.returncode != 0: + raise JujuError("command failed: %s" % cmd) def ssh( self, command, user=None, proxy=False, ssh_opts=None): @@ -206,3 +245,18 @@ class Machine(model.ModelEntity): """ return parse_date(self.safe_data['instance-status']['since']) + + @property + def dns_name(self): + """Get the DNS name for this machine. This is a best guess based on the + addresses available in current data. + + May return None if no suitable address is found. + """ + for scope in ['public', 'local-cloud']: + addresses = self.safe_data['addresses'] or [] + addresses = [address for address in addresses + if address['scope'] == scope] + if addresses: + return addresses[0]['value'] + return None diff --git a/juju/unit.py b/juju/unit.py index 0f2a51c..fc597bf 100644 --- a/juju/unit.py +++ b/juju/unit.py @@ -51,6 +51,17 @@ class Unit(model.ModelEntity): """ return self.safe_data['workload-status']['message'] + @property + def machine(self): + """Get the machine object for this unit. + + """ + machine_id = self.safe_data['machine-id'] + if machine_id: + return self.model.machines.get(machine_id, None) + else: + return None + @property def public_address(self): """ Get the public address. @@ -163,20 +174,31 @@ class Unit(model.ModelEntity): # action is complete, rather than just being in the model return await self.model._wait_for_new('action', action_id) - def scp( - self, source_path, user=None, destination_path=None, proxy=False, - scp_opts=None): + async def scp_to(self, source, destination, user='ubuntu', proxy=False, + scp_opts=''): """Transfer files to this unit. - :param str source_path: Path of file(s) to transfer + :param str source: Local path of file(s) to transfer + :param str destination: Remote destination of transferred files :param str user: Remote username - :param str destination_path: Destination of transferred files on - remote machine :param bool proxy: Proxy through the Juju API server :param str scp_opts: Additional options to the `scp` command + """ + await self.machine.scp_to(source, destination, user=user, proxy=proxy, + scp_opts=scp_opts) + async def scp_from(self, source, destination, user='ubuntu', proxy=False, + scp_opts=''): + """Transfer files from this unit. + + :param str source: Remote path of file(s) to transfer + :param str destination: Local destination of transferred files + :param str user: Remote username + :param bool proxy: Proxy through the Juju API server + :param str scp_opts: Additional options to the `scp` command """ - raise NotImplementedError() + await self.machine.scp_from(source, destination, user=user, + proxy=proxy, scp_opts=scp_opts) def set_meter_status(self): """Set the meter status on this unit. diff --git a/tests/integration/test_machine.py b/tests/integration/test_machine.py index 499a7d3..60de035 100644 --- a/tests/integration/test_machine.py +++ b/tests/integration/test_machine.py @@ -1,7 +1,8 @@ import asyncio - import pytest +from tempfile import NamedTemporaryFile + from .. import base @@ -35,3 +36,27 @@ async def test_status(event_loop): assert machine.status_message.lower() == 'running' assert machine.agent_status == 'started' assert machine.agent_version.major >= 2 + + +@base.bootstrapped +@pytest.mark.asyncio +async def test_scp(event_loop): + async with base.CleanModel() as model: + await model.add_machine() + await asyncio.wait_for( + model.block_until(lambda: model.machines), + timeout=240) + machine = model.machines['0'] + await asyncio.wait_for( + model.block_until(lambda: (machine.status == 'running' and + machine.agent_status == 'started')), + timeout=480) + + with NamedTemporaryFile() as f: + f.write(b'testcontents') + f.flush() + await machine.scp_to(f.name, 'testfile') + + with NamedTemporaryFile() as f: + await machine.scp_from('testfile', f.name) + assert f.read() == b'testcontents' diff --git a/tests/integration/test_unit.py b/tests/integration/test_unit.py index e9116ce..1604c31 100644 --- a/tests/integration/test_unit.py +++ b/tests/integration/test_unit.py @@ -1,5 +1,8 @@ +import asyncio import pytest +from tempfile import NamedTemporaryFile + from .. import base @@ -44,3 +47,32 @@ async def test_run_action(event_loop): action = await run_action(unit) assert action.results == {'dir': '/var/git/myrepo.git'} break + + +@base.bootstrapped +@pytest.mark.asyncio +async def test_scp(event_loop): + async with base.CleanModel() as model: + app = await model.deploy('ubuntu') + + await asyncio.wait_for( + model.block_until(lambda: app.units), + timeout=60) + unit = app.units[0] + await asyncio.wait_for( + model.block_until(lambda: unit.machine), + timeout=60) + machine = unit.machine + await asyncio.wait_for( + model.block_until(lambda: (machine.status == 'running' and + machine.agent_status == 'started')), + timeout=480) + + with NamedTemporaryFile() as f: + f.write(b'testcontents') + f.flush() + await unit.scp_to(f.name, 'testfile') + + with NamedTemporaryFile() as f: + await unit.scp_from('testfile', f.name) + assert f.read() == b'testcontents' -- 2.25.1