Allow specifying ssh key file for compute nodes and network controller
[osm/openvim.git] / osm_openvim / host_thread.py
index bf7a6da..1a03c7f 100644 (file)
@@ -34,20 +34,21 @@ import threading
 import time
 import Queue
 import paramiko
-from jsonschema import validate as js_v, exceptions as js_e
-#import libvirt
+# import subprocess
+# import libvirt
 import imp
-from vim_schema import localinfo_schema, hostinfo_schema
 import random
 import os
 import logging
+from jsonschema import validate as js_v, exceptions as js_e
+from vim_schema import localinfo_schema, hostinfo_schema
 
 
 class host_thread(threading.Thread):
     lvirt_module = None
 
     def __init__(self, name, host, user, db, db_lock, test, image_path, host_id, version, develop_mode,
-                 develop_bridge_iface, logger_name=None, debug=None):
+                 develop_bridge_iface, password=None, keyfile = None, logger_name=None, debug=None):
         '''Init a thread.
         Arguments:
             'id' number of thead
@@ -62,6 +63,8 @@ class host_thread(threading.Thread):
         self.db = db
         self.db_lock = db_lock
         self.test = test
+        self.password = password
+        self.keyfile =  keyfile
         self.localinfo_dirty = False
 
         if not test and not host_thread.lvirt_module:
@@ -97,33 +100,38 @@ class host_thread(threading.Thread):
         self.queueLock = threading.Lock()
         self.taskQueue = Queue.Queue(2000)
         self.ssh_conn = None
+        self.lvirt_conn_uri = "qemu+ssh://{user}@{host}/system?no_tty=1&no_verify=1".format(
+            user=self.user, host=self.host)
+        if keyfile:
+            self.lvirt_conn_uri += "&keyfile=" + keyfile
 
     def ssh_connect(self):
         try:
-            #Connect SSH
+            # Connect SSH
             self.ssh_conn = paramiko.SSHClient()
             self.ssh_conn.set_missing_host_key_policy(paramiko.AutoAddPolicy())
             self.ssh_conn.load_system_host_keys()
-            self.ssh_conn.connect(self.host, username=self.user, timeout=10) #, None)
+            self.ssh_conn.connect(self.host, username=self.user, password=self.password, key_filename=self.keyfile,
+                                  timeout=10) #, None)
         except paramiko.ssh_exception.SSHException as e:
             text = e.args[0]
             self.logger.error("ssh_connect ssh Exception: " + text)
-        
+
     def load_localinfo(self):
         if not self.test:
             try:
-                #Connect SSH
+                # Connect SSH
                 self.ssh_connect()
-    
+
                 command = 'mkdir -p ' +  self.image_path
-                #print self.name, ': command:', command
+                # print self.name, ': command:', command
                 (_, stdout, stderr) = self.ssh_conn.exec_command(command)
                 content = stderr.read()
                 if len(content) > 0:
                     self.logger.error("command: '%s' stderr: '%s'", command, content)
 
                 command = 'cat ' +  self.image_path + '/.openvim.yaml'
-                #print self.name, ': command:', command
+                # print self.name, ': command:', command
                 (_, stdout, stderr) = self.ssh_conn.exec_command(command)
                 content = stdout.read()
                 if len(content) == 0:
@@ -136,7 +144,7 @@ class host_thread(threading.Thread):
                     self.localinfo['server_files'] = {}
                 self.logger.debug("localinfo load from host")
                 return
-    
+
             except paramiko.ssh_exception.SSHException as e:
                 text = e.args[0]
                 self.logger.error("load_localinfo ssh Exception: " + text)
@@ -1610,12 +1618,12 @@ class host_thread(threading.Thread):
                             # VIR_DOMAIN_SHUTOFF = 5
                             # VIR_DOMAIN_CRASHED = 6
                             # VIR_DOMAIN_PMSUSPENDED = 7   #TODO suspended
-    
+
         if self.test or len(self.server_status)==0:
-            return            
-        
+            return
+
         try:
-            conn = host_thread.lvirt_module.open("qemu+ssh://"+self.user+"@"+self.host+"/system")
+            conn = host_thread.lvirt_module.open(self.lvirt_conn_uri)
             domains=  conn.listAllDomains() 
             domain_dict={}
             for domain in domains:
@@ -1704,7 +1712,7 @@ class host_thread(threading.Thread):
                 self.create_image(None, req)
         else:
             try:
-                conn = host_thread.lvirt_module.open("qemu+ssh://"+self.user+"@"+self.host+"/system")
+                conn = host_thread.lvirt_module.open(self.lvirt_conn_uri)
                 try:
                     dom = conn.lookupByUUIDString(server_id)
                 except host_thread.lvirt_module.libvirtError as e:
@@ -1898,7 +1906,7 @@ class host_thread(threading.Thread):
             return 0, None
         try:
             if not lib_conn:
-                conn = host_thread.lvirt_module.open("qemu+ssh://"+self.user+"@"+self.host+"/system")
+                conn = host_thread.lvirt_module.open(self.lvirt_conn_uri)
             else:
                 conn = lib_conn
                 
@@ -1995,12 +2003,12 @@ class host_thread(threading.Thread):
             xml.append("<interface type='hostdev' managed='yes'>")
             xml.append("  <mac address='" +port['mac']+ "'/>")
             xml.append("  <source>"+ self.pci2xml(port['pci'])+"\n  </source>")
-            xml.append('</interface>')                
+            xml.append('</interface>')
 
             
             try:
                 conn=None
-                conn = host_thread.lvirt_module.open("qemu+ssh://"+self.user+"@"+self.host+"/system")
+                conn = host_thread.lvirt_module.open(self.lvirt_conn_uri)
                 dom = conn.lookupByUUIDString(port["instance_id"])
                 if old_net:
                     text="\n".join(xml)