# blob: 880c5cb0cfaefac520164af107686d80ca16cabe [file] [log] [blame]
# Copyright 2020 Canonical Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase, mock
from mock import mock_open
from n2vc.provisioner import SSHProvisioner
from paramiko.ssh_exception import SSHException
class ProvisionerTest(TestCase):
    """Unit tests for ``SSHProvisioner._get_ssh_client``.

    Every test patches ``paramiko.SSHClient``, ``paramiko.RSAKey`` and
    ``os.path.exists`` as looked up through ``n2vc.provisioner``, plus the
    builtin ``open``, so the tests run against mocks only.

    ``mock.patch`` decorators are applied bottom-up: the mock created by the
    decorator closest to the ``def`` is the first extra positional argument.
    Arguments prefixed with ``_`` are injected by a patch but not inspected.
    """

    def setUp(self):
        """Build the provisioner under test with all three arguments ``None``."""
        self.provisioner = SSHProvisioner(None, None, None)

    def _assert_call_counts(self, client, connect_calls, connect_msg):
        """Check how the mocked SSH client was used.

        Asserts that ``set_missing_host_key_policy`` was called exactly once
        and that ``connect`` was called ``connect_calls`` times.

        Args:
            client: The mocked ``SSHClient`` instance.
            connect_calls: Expected number of ``connect`` invocations.
            connect_msg: Failure message for the ``connect`` count assertion.
        """
        self.assertEqual(
            1,
            client.set_missing_host_key_policy.call_count,
            "Missing host key call count",
        )
        self.assertEqual(connect_calls, client.connect.call_count, connect_msg)

    @mock.patch("n2vc.provisioner.os.path.exists")
    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
    def test__get_ssh_client(self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os):
        """A connect that succeeds first time returns the client after one call."""
        mock_instance = mock_sshclient.return_value

        sshclient = self.provisioner._get_ssh_client()

        self.assertEqual(mock_instance, sshclient)
        self._assert_call_counts(mock_instance, 1, "Connect call count")

    @mock.patch("n2vc.provisioner.os.path.exists")
    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
    def test__get_ssh_client_no_connection(
        self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os
    ):
        """An ``SSHException`` with no message propagates after one connect."""
        mock_instance = mock_sshclient.return_value
        mock_instance.connect.side_effect = SSHException()

        self.assertRaises(SSHException, self.provisioner._get_ssh_client)

        self._assert_call_counts(mock_instance, 1, "Connect call count")

    @mock.patch("n2vc.provisioner.os.path.exists")
    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
    def test__get_ssh_client_bad_banner(
        self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os
    ):
        """A banner-read ``SSHException`` on the first connect is recovered from.

        The first ``connect`` raises, the next two succeed; the client must
        still be returned and ``connect`` must have been called three times.
        """
        mock_instance = mock_sshclient.return_value
        # One outcome per expected connect() call, consumed in order.
        mock_instance.connect.side_effect = [
            SSHException("Error reading SSH protocol banner"),
            None,
            None,
        ]

        sshclient = self.provisioner._get_ssh_client()

        self.assertEqual(mock_instance, sshclient)
        self._assert_call_counts(mock_instance, 3, "Should attempt 3 connections")

    @mock.patch("time.sleep", autospec=True)
    @mock.patch("n2vc.provisioner.os.path.exists")
    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
    def test__get_ssh_client_unable_to_connect(
        self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os, _mock_sleep
    ):
        """A persistent "Unable to connect to port" error gives up after 11 tries.

        ``time.sleep`` is patched so the test does not really wait; presumably
        the provisioner sleeps between attempts - confirm against
        ``n2vc.provisioner``.
        """
        mock_instance = mock_sshclient.return_value
        # A single exception instance as side_effect is raised on every call.
        mock_instance.connect.side_effect = Exception("Unable to connect to port")

        self.assertRaises(Exception, self.provisioner._get_ssh_client)

        self._assert_call_counts(mock_instance, 11, "Should attempt 11 connections")

    @mock.patch("time.sleep", autospec=True)
    @mock.patch("n2vc.provisioner.os.path.exists")
    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
    def test__get_ssh_client_unable_to_connect_once(
        self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os, _mock_sleep
    ):
        """One "Unable to connect to port" failure is followed by a success.

        The client must be returned after exactly two ``connect`` calls.
        ``time.sleep`` is patched so the test does not really wait.
        """
        mock_instance = mock_sshclient.return_value
        # First call fails, second call succeeds.
        mock_instance.connect.side_effect = [
            Exception("Unable to connect to port"),
            None,
        ]

        sshclient = self.provisioner._get_ssh_client()

        self.assertEqual(mock_instance, sshclient)
        self._assert_call_counts(mock_instance, 2, "Should attempt 2 connections")

    @mock.patch("n2vc.provisioner.os.path.exists")
    @mock.patch("n2vc.provisioner.paramiko.RSAKey")
    @mock.patch("n2vc.provisioner.paramiko.SSHClient")
    @mock.patch("builtins.open", new_callable=mock_open, read_data="data")
    def test__get_ssh_client_other_exception(
        self, _mock_open, mock_sshclient, _mock_rsakey, _mock_os
    ):
        """A generic ``Exception`` with no message propagates after one connect."""
        mock_instance = mock_sshclient.return_value
        mock_instance.connect.side_effect = Exception()

        self.assertRaises(Exception, self.provisioner._get_ssh_client)

        self._assert_call_counts(
            mock_instance, 1, "Should only attempt 1 connection"
        )
#