Implement scp functionality (#149)
authorCynerva <cynerva@gmail.com>
Thu, 29 Jun 2017 18:48:40 +0000 (13:48 -0500)
committerCory Johns <johnsca@gmail.com>
Thu, 29 Jun 2017 18:48:40 +0000 (14:48 -0400)
* Redefine SCP interface

* Implement Machine.scp_to and Machine.scp_from

* Implement Unit.scp_to and Unit.scp_from

* Fix Unit.machine to better handle missing machines

* Use os.path.expanduser for scp key path

* Use local-cloud address if public address is not found

* Rename Machine.public_address to Machine.dns_name to match CLI

* fix scp tests not waiting long enough

juju/machine.py
juju/unit.py
tests/integration/test_machine.py
tests/integration/test_unit.py

index 18c333c..23b41c6 100644 (file)
@@ -1,9 +1,12 @@
+import asyncio
 import logging
+import os
 
 from dateutil.parser import parse as parse_date
 
 from . import model, utils
 from .client import client
+from .errors import JujuError
 
 log = logging.getLogger(__name__)
 
@@ -124,20 +127,56 @@ class Machine(model.ModelEntity):
         )
         return await self.ann_facade.Set([ann])
 
-    def scp(
-            self, source_path, user=None, destination_path=None, proxy=False,
-            scp_opts=None):
+    async def scp_to(self, source, destination, user='ubuntu', proxy=False,
+                     scp_opts=''):
         """Transfer files to this machine.
 
-        :param str source_path: Path of file(s) to transfer
+        :param str source: Local path of file(s) to transfer
+        :param str destination: Remote destination of transferred files
         :param str user: Remote username
-        :param str destination_path: Destination of transferred files on
-            remote machine
         :param bool proxy: Proxy through the Juju API server
         :param str scp_opts: Additional options to the `scp` command
+        """
+        if proxy:
+            raise NotImplementedError('proxy option is not implemented')
+
+        address = self.dns_name
+        destination = '%s@%s:%s' % (user, address, destination)
+        await self._scp(source, destination, scp_opts)
 
+    async def scp_from(self, source, destination, user='ubuntu', proxy=False,
+                       scp_opts=''):
+        """Transfer files from this machine.
+
+        :param str source: Remote path of file(s) to transfer
+        :param str destination: Local destination of transferred files
+        :param str user: Remote username
+        :param bool proxy: Proxy through the Juju API server
+        :param str scp_opts: Additional options to the `scp` command
         """
-        raise NotImplementedError()
+        if proxy:
+            raise NotImplementedError('proxy option is not implemented')
+
+        address = self.dns_name
+        source = '%s@%s:%s' % (user, address, source)
+        await self._scp(source, destination, scp_opts)
+
+    async def _scp(self, source, destination, scp_opts):
+        """ Execute an scp command. Requires a fully qualified source and
+        destination.
+        """
+        cmd = [
+            'scp',
+            '-i', os.path.expanduser('~/.local/share/juju/ssh/juju_id_rsa'),
+            '-o', 'StrictHostKeyChecking=no',
+            source, destination
+        ]
+        cmd += scp_opts.split()
+        loop = self.model.loop
+        process = await asyncio.create_subprocess_exec(*cmd, loop=loop)
+        await process.wait()
+        if process.returncode != 0:
+            raise JujuError("command failed: %s" % cmd)
 
     def ssh(
             self, command, user=None, proxy=False, ssh_opts=None):
@@ -206,3 +245,18 @@ class Machine(model.ModelEntity):
 
         """
         return parse_date(self.safe_data['instance-status']['since'])
+
+    @property
+    def dns_name(self):
+        """Get the DNS name for this machine. This is a best guess based on the
+        addresses available in current data.
+
+        May return None if no suitable address is found.
+        """
+        for scope in ['public', 'local-cloud']:
+            addresses = self.safe_data['addresses'] or []
+            addresses = [address for address in addresses
+                         if address['scope'] == scope]
+            if addresses:
+                return addresses[0]['value']
+        return None
index 0f2a51c..fc597bf 100644 (file)
@@ -51,6 +51,17 @@ class Unit(model.ModelEntity):
         """
         return self.safe_data['workload-status']['message']
 
+    @property
+    def machine(self):
+        """Get the machine object for this unit.
+
+        """
+        machine_id = self.safe_data['machine-id']
+        if machine_id:
+            return self.model.machines.get(machine_id, None)
+        else:
+            return None
+
     @property
     def public_address(self):
         """ Get the public address.
@@ -163,20 +174,31 @@ class Unit(model.ModelEntity):
         # action is complete, rather than just being in the model
         return await self.model._wait_for_new('action', action_id)
 
-    def scp(
-            self, source_path, user=None, destination_path=None, proxy=False,
-            scp_opts=None):
+    async def scp_to(self, source, destination, user='ubuntu', proxy=False,
+                     scp_opts=''):
         """Transfer files to this unit.
 
-        :param str source_path: Path of file(s) to transfer
+        :param str source: Local path of file(s) to transfer
+        :param str destination: Remote destination of transferred files
         :param str user: Remote username
-        :param str destination_path: Destination of transferred files on
-            remote machine
         :param bool proxy: Proxy through the Juju API server
         :param str scp_opts: Additional options to the `scp` command
+        """
+        await self.machine.scp_to(source, destination, user=user, proxy=proxy,
+                                  scp_opts=scp_opts)
 
+    async def scp_from(self, source, destination, user='ubuntu', proxy=False,
+                       scp_opts=''):
+        """Transfer files from this unit.
+
+        :param str source: Remote path of file(s) to transfer
+        :param str destination: Local destination of transferred files
+        :param str user: Remote username
+        :param bool proxy: Proxy through the Juju API server
+        :param str scp_opts: Additional options to the `scp` command
         """
-        raise NotImplementedError()
+        await self.machine.scp_from(source, destination, user=user,
+                                    proxy=proxy, scp_opts=scp_opts)
 
     def set_meter_status(self):
         """Set the meter status on this unit.
index 499a7d3..60de035 100644 (file)
@@ -1,7 +1,8 @@
 import asyncio
-
 import pytest
 
+from tempfile import NamedTemporaryFile
+
 from .. import base
 
 
@@ -35,3 +36,27 @@ async def test_status(event_loop):
         assert machine.status_message.lower() == 'running'
         assert machine.agent_status == 'started'
         assert machine.agent_version.major >= 2
+
+
+@base.bootstrapped
+@pytest.mark.asyncio
+async def test_scp(event_loop):
+    async with base.CleanModel() as model:
+        await model.add_machine()
+        await asyncio.wait_for(
+            model.block_until(lambda: model.machines),
+            timeout=240)
+        machine = model.machines['0']
+        await asyncio.wait_for(
+            model.block_until(lambda: (machine.status == 'running' and
+                                       machine.agent_status == 'started')),
+            timeout=480)
+
+        with NamedTemporaryFile() as f:
+            f.write(b'testcontents')
+            f.flush()
+            await machine.scp_to(f.name, 'testfile')
+
+        with NamedTemporaryFile() as f:
+            await machine.scp_from('testfile', f.name)
+            assert f.read() == b'testcontents'
index e9116ce..1604c31 100644 (file)
@@ -1,5 +1,8 @@
+import asyncio
 import pytest
 
+from tempfile import NamedTemporaryFile
+
 from .. import base
 
 
@@ -44,3 +47,32 @@ async def test_run_action(event_loop):
             action = await run_action(unit)
             assert action.results == {'dir': '/var/git/myrepo.git'}
             break
+
+
+@base.bootstrapped
+@pytest.mark.asyncio
+async def test_scp(event_loop):
+    async with base.CleanModel() as model:
+        app = await model.deploy('ubuntu')
+
+        await asyncio.wait_for(
+            model.block_until(lambda: app.units),
+            timeout=60)
+        unit = app.units[0]
+        await asyncio.wait_for(
+            model.block_until(lambda: unit.machine),
+            timeout=60)
+        machine = unit.machine
+        await asyncio.wait_for(
+            model.block_until(lambda: (machine.status == 'running' and
+                                       machine.agent_status == 'started')),
+            timeout=480)
+
+        with NamedTemporaryFile() as f:
+            f.write(b'testcontents')
+            f.flush()
+            await unit.scp_to(f.name, 'testfile')
+
+        with NamedTemporaryFile() as f:
+            await unit.scp_from('testfile', f.name)
+            assert f.read() == b'testcontents'