Enable lint, flake8 and unit tests
[osm/N2VC.git] / n2vc / tests / unit / test_provisioner.py
diff --git a/n2vc/tests/unit/test_provisioner.py b/n2vc/tests/unit/test_provisioner.py
new file mode 100644 (file)
index 0000000..880c5cb
--- /dev/null
@@ -0,0 +1,158 @@
+# Copyright 2020 Canonical Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+#     Unless required by applicable law or agreed to in writing, software
+#     distributed under the License is distributed on an "AS IS" BASIS,
+#     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#     See the License for the specific language governing permissions and
+#     limitations under the License.
+
+from unittest import TestCase, mock
+
+from mock import mock_open
+from n2vc.provisioner import SSHProvisioner
+from paramiko.ssh_exception import SSHException
+
+
+class ProvisionerTest(TestCase):
+    def setUp(self):
+        self.provisioner = SSHProvisioner(None, None, None)
+
+    @mock.patch("n2vc.provisioner.os.path.exists")
+    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
+    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
+    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
+    def test__get_ssh_client(self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os):
+        mock_instance = mock_sshclient.return_value
+        sshclient = self.provisioner._get_ssh_client()
+        self.assertEqual(mock_instance, sshclient)
+        self.assertEqual(
+            1,
+            mock_instance.set_missing_host_key_policy.call_count,
+            "Missing host key call count",
+        )
+        self.assertEqual(1, mock_instance.connect.call_count, "Connect call count")
+
+    @mock.patch("n2vc.provisioner.os.path.exists")
+    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
+    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
+    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
+    def test__get_ssh_client_no_connection(
+        self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os
+    ):
+
+        mock_instance = mock_sshclient.return_value
+        mock_instance.method_inside_someobject.side_effect = ["something"]
+        mock_instance.connect.side_effect = SSHException()
+
+        self.assertRaises(SSHException, self.provisioner._get_ssh_client)
+        self.assertEqual(
+            1,
+            mock_instance.set_missing_host_key_policy.call_count,
+            "Missing host key call count",
+        )
+        self.assertEqual(1, mock_instance.connect.call_count, "Connect call count")
+
+    @mock.patch("n2vc.provisioner.os.path.exists")
+    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
+    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
+    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
+    def test__get_ssh_client_bad_banner(
+        self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os
+    ):
+
+        mock_instance = mock_sshclient.return_value
+        mock_instance.method_inside_someobject.side_effect = ["something"]
+        mock_instance.connect.side_effect = [
+            SSHException("Error reading SSH protocol banner"),
+            None,
+            None,
+        ]
+
+        sshclient = self.provisioner._get_ssh_client()
+        self.assertEqual(mock_instance, sshclient)
+        self.assertEqual(
+            1,
+            mock_instance.set_missing_host_key_policy.call_count,
+            "Missing host key call count",
+        )
+        self.assertEqual(
+            3, mock_instance.connect.call_count, "Should attempt 3 connections"
+        )
+
+    @mock.patch("time.sleep", autospec=True)
+    @mock.patch("n2vc.provisioner.os.path.exists")
+    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
+    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
+    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
+    def test__get_ssh_client_unable_to_connect(
+        self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os, _mock_sleep
+    ):
+
+        mock_instance = mock_sshclient.return_value
+        mock_instance.connect.side_effect = Exception("Unable to connect to port")
+
+        self.assertRaises(Exception, self.provisioner._get_ssh_client)
+        self.assertEqual(
+            1,
+            mock_instance.set_missing_host_key_policy.call_count,
+            "Missing host key call count",
+        )
+        self.assertEqual(
+            11, mock_instance.connect.call_count, "Should attempt 11 connections"
+        )
+
+    @mock.patch("time.sleep", autospec=True)
+    @mock.patch("n2vc.provisioner.os.path.exists")
+    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
+    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
+    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
+    def test__get_ssh_client_unable_to_connect_once(
+        self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os, _mock_sleep
+    ):
+
+        mock_instance = mock_sshclient.return_value
+        mock_instance.connect.side_effect = [
+            Exception("Unable to connect to port"),
+            None,
+        ]
+
+        sshclient = self.provisioner._get_ssh_client()
+        self.assertEqual(mock_instance, sshclient)
+        self.assertEqual(
+            1,
+            mock_instance.set_missing_host_key_policy.call_count,
+            "Missing host key call count",
+        )
+        self.assertEqual(
+            2, mock_instance.connect.call_count, "Should attempt 2 connections"
+        )
+
+    @mock.patch("n2vc.provisioner.os.path.exists")
+    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
+    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
+    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
+    def test__get_ssh_client_other_exception(
+        self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os
+    ):
+
+        mock_instance = mock_sshclient.return_value
+        mock_instance.connect.side_effect = Exception()
+
+        self.assertRaises(Exception, self.provisioner._get_ssh_client)
+        self.assertEqual(
+            1,
+            mock_instance.set_missing_host_key_policy.call_count,
+            "Missing host key call count",
+        )
+        self.assertEqual(
+            1, mock_instance.connect.call_count, "Should only attempt 1 connection"
+        )
+
+
+#