aboutsummaryrefslogtreecommitdiffstats
path: root/yardstick/ssh.py
diff options
context:
space:
mode:
Diffstat (limited to 'yardstick/ssh.py')
-rw-r--r--yardstick/ssh.py147
1 files changed, 129 insertions, 18 deletions
diff --git a/yardstick/ssh.py b/yardstick/ssh.py
index cf9adf0dc..8ac3eaa3a 100644
--- a/yardstick/ssh.py
+++ b/yardstick/ssh.py
@@ -77,8 +77,8 @@ from oslo_utils import encodeutils
from scp import SCPClient
import six
-
-SSH_PORT = paramiko.config.SSH_PORT
+from yardstick.common.utils import try_int
+from yardstick.network_services.utils import provision_tool
class SSHError(Exception):
@@ -92,7 +92,26 @@ class SSHTimeout(SSHError):
class SSH(object):
"""Represent ssh connection."""
- def __init__(self, user, host, port=SSH_PORT, pkey=None,
+ SSH_PORT = paramiko.config.SSH_PORT
+
+ @staticmethod
+ def gen_keys(key_filename, bit_count=2048):
+ rsa_key = paramiko.RSAKey.generate(bits=bit_count, progress_func=None)
+ rsa_key.write_private_key_file(key_filename)
+ print("Writing %s ..." % key_filename)
+ with open('.'.join([key_filename, "pub"]), "w") as pubkey_file:
+ pubkey_file.write(rsa_key.get_name())
+ pubkey_file.write(' ')
+ pubkey_file.write(rsa_key.get_base64())
+ pubkey_file.write('\n')
+
+ @staticmethod
+ def get_class():
+ # must return static class name, anything else refers to the calling class
+ # i.e. the subclass, not the superclass
+ return SSH
+
+ def __init__(self, user, host, port=None, pkey=None,
key_filename=None, password=None, name=None):
"""Initialize SSH client.
@@ -115,7 +134,7 @@ class SSH(object):
self.log.debug("user:%s host:%s", user, host)
# we may get text port from YAML, convert to int
- self.port = int(port)
+ self.port = try_int(port, self.SSH_PORT)
self.pkey = self._get_pkey(pkey) if pkey else None
self.password = password
self.key_filename = key_filename
@@ -129,21 +148,25 @@ class SSH(object):
logging.getLogger("paramiko").setLevel(logging.WARN)
@classmethod
- def from_node(cls, node, overrides=None, defaults=None):
+ def args_from_node(cls, node, overrides=None, defaults=None):
if overrides is None:
overrides = {}
if defaults is None:
defaults = {}
params = ChainMap(overrides, node, defaults)
- return cls(
- user=params['user'],
- host=params['ip'],
- # paramiko doesn't like None default, requires SSH_PORT default
- port=params.get('ssh_port', SSH_PORT),
- pkey=params.get('pkey'),
- key_filename=params.get('key_filename'),
- password=params.get('password'),
- name=params.get('name'))
+ return {
+ 'user': params['user'],
+ 'host': params['ip'],
+ 'port': params.get('ssh_port', cls.SSH_PORT),
+ 'pkey': params.get('pkey'),
+ 'key_filename': params.get('key_filename'),
+ 'password': params.get('password'),
+ 'name': params.get('name'),
+ }
+
+ @classmethod
+ def from_node(cls, node, overrides=None, defaults=None):
+ return cls(**cls.args_from_node(node, overrides, defaults))
def _get_pkey(self, key):
if isinstance(key, six.string_types):
@@ -156,8 +179,12 @@ class SSH(object):
errors.append(e)
raise SSHError("Invalid pkey: %s" % (errors))
+ @property
+ def is_connected(self):
+ return bool(self._client)
+
def _get_client(self):
- if self._client:
+ if self.is_connected:
return self._client
try:
self._client = paramiko.SSHClient()
@@ -176,9 +203,24 @@ class SSH(object):
raise SSHError(message % {"exception": e,
"exception_type": type(e)})
+ def _make_dict(self):
+ return {
+ 'user': self.user,
+ 'host': self.host,
+ 'port': self.port,
+ 'pkey': self.pkey,
+ 'key_filename': self.key_filename,
+ 'password': self.password,
+ 'name': self.name,
+ }
+
+ def copy(self):
+ return self.get_class()(**self._make_dict())
+
def close(self):
- self._client.close()
- self._client = False
+ if self._client:
+ self._client.close()
+ self._client = False
def run(self, cmd, stdin=None, stdout=None, stderr=None,
raise_on_error=True, timeout=3600,
@@ -308,7 +350,7 @@ class SSH(object):
timeout=timeout, raise_on_error=False)
stdout.seek(0)
stderr.seek(0)
- return (exit_status, stdout.read(), stderr.read())
+ return exit_status, stdout.read(), stderr.read()
def wait(self, timeout=120, interval=1):
"""Wait for the host will be available via ssh."""
@@ -369,3 +411,72 @@ class SSH(object):
self._put_file_sftp(localpath, remotepath, mode=mode)
except (paramiko.SSHException, socket.error):
self._put_file_shell(localpath, remotepath, mode=mode)
+
+ def provision_tool(self, tool_path, tool_file=None):
+ return provision_tool(self, tool_path, tool_file)
+
+ def put_file_obj(self, file_obj, remotepath, mode=None):
+ client = self._get_client()
+
+ with client.open_sftp() as sftp:
+ sftp.putfo(file_obj, remotepath)
+ if mode is not None:
+ sftp.chmod(remotepath, mode)
+
+
+class AutoConnectSSH(SSH):
+
+ def __init__(self, user, host, port=None, pkey=None,
+ key_filename=None, password=None, name=None, wait=False):
+ super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
+ self._wait = wait
+
+ def _make_dict(self):
+ data = super(AutoConnectSSH, self)._make_dict()
+ data.update({
+ 'wait': self._wait
+ })
+ return data
+
+ def _connect(self):
+ if not self.is_connected:
+ self._get_client()
+ if self._wait:
+ self.wait()
+
+ def drop_connection(self):
+ """ Don't close anything, just force creation of a new client """
+ self._client = False
+
+ def execute(self, cmd, stdin=None, timeout=3600):
+ self._connect()
+ return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
+
+ def run(self, cmd, stdin=None, stdout=None, stderr=None,
+ raise_on_error=True, timeout=3600,
+ keep_stdin_open=False, pty=False):
+ self._connect()
+ return super(AutoConnectSSH, self).run(cmd, stdin, stdout, stderr, raise_on_error,
+ timeout, keep_stdin_open, pty)
+
+ def put(self, files, remote_path=b'.', recursive=False):
+ self._connect()
+ return super(AutoConnectSSH, self).put(files, remote_path, recursive)
+
+ def put_file(self, local_path, remote_path, mode=None):
+ self._connect()
+ return super(AutoConnectSSH, self).put_file(local_path, remote_path, mode)
+
+ def put_file_obj(self, file_obj, remote_path, mode=None):
+ self._connect()
+ return super(AutoConnectSSH, self).put_file_obj(file_obj, remote_path, mode)
+
+ def provision_tool(self, tool_path, tool_file=None):
+ self._connect()
+ return super(AutoConnectSSH, self).provision_tool(tool_path, tool_file)
+
+ @staticmethod
+ def get_class():
+ # must return static class name, anything else refers to the calling class
+ # i.e. the subclass, not the superclass
+ return AutoConnectSSH