aboutsummaryrefslogtreecommitdiffstats
path: root/yardstick/ssh.py
diff options
context:
space:
mode:
Diffstat (limited to 'yardstick/ssh.py')
-rw-r--r--yardstick/ssh.py204
1 files changed, 154 insertions, 50 deletions
diff --git a/yardstick/ssh.py b/yardstick/ssh.py
index 6ddf327f2..6bc6010f7 100644
--- a/yardstick/ssh.py
+++ b/yardstick/ssh.py
@@ -62,15 +62,13 @@ Eventlet:
sshclient = eventlet.import_patched("yardstick.ssh")
"""
-from __future__ import absolute_import
-import os
import io
+import logging
+import os
+import re
import select
import socket
import time
-import re
-
-import logging
import paramiko
from chainmap import ChainMap
@@ -78,9 +76,11 @@ from oslo_utils import encodeutils
from scp import SCPClient
import six
-from yardstick.common.utils import try_int
+from yardstick.common import exceptions
+from yardstick.common.utils import try_int, NON_NONE_DEFAULT, make_dict_from_map
from yardstick.network_services.utils import provision_tool
+LOG = logging.getLogger(__name__)
def convert_key_to_str(key):
if not isinstance(key, (paramiko.RSAKey, paramiko.DSSKey)):
@@ -90,18 +90,11 @@ def convert_key_to_str(key):
return k.getvalue()
-class SSHError(Exception):
- pass
-
-
-class SSHTimeout(SSHError):
- pass
-
-
class SSH(object):
"""Represent ssh connection."""
SSH_PORT = paramiko.config.SSH_PORT
+ DEFAULT_WAIT_TIMEOUT = 120
@staticmethod
def gen_keys(key_filename, bit_count=2048):
@@ -120,6 +113,18 @@ class SSH(object):
# i.e. the subclass, not the superclass
return SSH
+ @classmethod
+ def get_arg_key_map(cls):
+ return {
+ 'user': ('user', NON_NONE_DEFAULT),
+ 'host': ('ip', NON_NONE_DEFAULT),
+ 'port': ('ssh_port', cls.SSH_PORT),
+ 'pkey': ('pkey', None),
+ 'key_filename': ('key_filename', None),
+ 'password': ('password', None),
+ 'name': ('name', None),
+ }
+
def __init__(self, user, host, port=None, pkey=None,
key_filename=None, password=None, name=None):
"""Initialize SSH client.
@@ -137,6 +142,7 @@ class SSH(object):
else:
self.log = logging.getLogger(__name__)
+ self.wait_timeout = self.DEFAULT_WAIT_TIMEOUT
self.user = user
self.host = host
# everybody wants to debug this in the caller, do it here instead
@@ -162,16 +168,9 @@ class SSH(object):
overrides = {}
if defaults is None:
defaults = {}
+
params = ChainMap(overrides, node, defaults)
- return {
- 'user': params['user'],
- 'host': params['ip'],
- 'port': params.get('ssh_port', cls.SSH_PORT),
- 'pkey': params.get('pkey'),
- 'key_filename': params.get('key_filename'),
- 'password': params.get('password'),
- 'name': params.get('name'),
- }
+ return make_dict_from_map(params, cls.get_arg_key_map())
@classmethod
def from_node(cls, node, overrides=None, defaults=None):
@@ -186,7 +185,7 @@ class SSH(object):
return key_class.from_private_key(key)
except paramiko.SSHException as e:
errors.append(e)
- raise SSHError("Invalid pkey: %s" % (errors))
+ raise exceptions.SSHError(error_msg='Invalid pkey: %s' % errors)
@property
def is_connected(self):
@@ -207,10 +206,10 @@ class SSH(object):
return self._client
except Exception as e:
message = ("Exception %(exception_type)s was raised "
- "during connect. Exception value is: %(exception)r")
+ "during connect. Exception value is: %(exception)r" %
+ {"exception": e, "exception_type": type(e)})
self._client = False
- raise SSHError(message % {"exception": e,
- "exception_type": type(e)})
+ raise exceptions.SSHError(error_msg=message)
def _make_dict(self):
return {
@@ -287,7 +286,7 @@ class SSH(object):
while True:
# Block until data can be read/write.
- r, w, e = select.select([session], writes, [session], 1)
+ e = select.select([session], writes, [session], 1)[2]
if session.recv_ready():
data = encodeutils.safe_decode(session.recv(4096), 'utf-8')
@@ -327,11 +326,11 @@ class SSH(object):
break
if timeout and (time.time() - timeout) > start_time:
- args = {"cmd": cmd, "host": self.host}
- raise SSHTimeout("Timeout executing command "
- "'%(cmd)s' on host %(host)s" % args)
+ message = ('Timeout executing command %(cmd)s on host %(host)s'
+ % {"cmd": cmd, "host": self.host})
+ raise exceptions.SSHTimeout(error_msg=message)
if e:
- raise SSHError("Socket error.")
+ raise exceptions.SSHError(error_msg='Socket error')
exit_status = session.recv_exit_status()
if exit_status != 0 and raise_on_error:
@@ -339,15 +338,18 @@ class SSH(object):
details = fmt % {"cmd": cmd, "status": exit_status}
if stderr_data:
details += " Last stderr data: '%s'." % stderr_data
- raise SSHError(details)
+ LOG.critical("PROX ERROR: %s", details)
+ raise exceptions.SSHError(error_msg=details)
return exit_status
- def execute(self, cmd, stdin=None, timeout=3600):
+ def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
"""Execute the specified command on the server.
- :param cmd: Command to be executed.
- :param stdin: Open file to be sent on process stdin.
- :param timeout: Timeout for execution of the command.
+ :param cmd: (str) Command to be executed.
+ :param stdin: (StringIO) Open file to be sent on process stdin.
+ :param timeout: (int) Timeout for execution of the command.
+ :param raise_on_error: (bool) If True, then an SSHError will be raised
+ when non-zero exit code.
:returns: tuple (exit_status, stdout, stderr)
"""
@@ -356,22 +358,26 @@ class SSH(object):
exit_status = self.run(cmd, stderr=stderr,
stdout=stdout, stdin=stdin,
- timeout=timeout, raise_on_error=False)
+ timeout=timeout, raise_on_error=raise_on_error)
stdout.seek(0)
stderr.seek(0)
return exit_status, stdout.read(), stderr.read()
- def wait(self, timeout=120, interval=1):
+ def wait(self, timeout=None, interval=1):
"""Wait for the host will be available via ssh."""
- start_time = time.time()
+ if timeout is None:
+ timeout = self.wait_timeout
+
+ end_time = time.time() + timeout
while True:
try:
return self.execute("uname")
- except (socket.error, SSHError) as e:
+ except (socket.error, exceptions.SSHError) as e:
self.log.debug("Ssh is still unavailable: %r", e)
time.sleep(interval)
- if time.time() > (start_time + timeout):
- raise SSHTimeout("Timeout waiting for '%s'", self.host)
+ if time.time() > end_time:
+ raise exceptions.SSHTimeout(
+ error_msg='Timeout waiting for "%s"' % self.host)
def put(self, files, remote_path=b'.', recursive=False):
client = self._get_client()
@@ -444,35 +450,133 @@ class SSH(object):
with client.open_sftp() as sftp:
sftp.getfo(remotepath, file_obj)
+ def interactive_terminal_open(self, time_out=45):
+ """Open interactive terminal on a SSH channel.
+
+ :param time_out: Timeout in seconds.
+ :returns: SSH channel with opened terminal.
+
+ .. warning:: Interruptingcow is used here, and it uses
+ signal(SIGALRM) to let the operating system interrupt program
+ execution. This has the following limitations: Python signal
+ handlers only apply to the main thread, so you cannot use this
+ from other threads. You must not use this in a program that
+ uses SIGALRM itself (this includes certain profilers)
+ """
+ chan = self._get_client().get_transport().open_session()
+ chan.get_pty()
+ chan.invoke_shell()
+ chan.settimeout(int(time_out))
+ chan.set_combine_stderr(True)
+
+ buf = ''
+ while not buf.endswith((":~# ", ":~$ ", "~]$ ", "~]# ")):
+ try:
+ chunk = chan.recv(10 * 1024 * 1024)
+ if not chunk:
+ break
+ buf += chunk
+ if chan.exit_status_ready():
+ self.log.error('Channel exit status ready')
+ break
+ except socket.timeout:
+ raise exceptions.SSHTimeout(error_msg='Socket timeout: %s' % buf)
+ return chan
+
+ def interactive_terminal_exec_command(self, chan, cmd, prompt):
+ """Execute command on interactive terminal.
+
+ interactive_terminal_open() method has to be called first!
+
+ :param chan: SSH channel with opened terminal.
+ :param cmd: Command to be executed.
+ :param prompt: Command prompt, sequence of characters used to
+ indicate readiness to accept commands.
+ :returns: Command output.
+
+ .. warning:: Interruptingcow is used here, and it uses
+ signal(SIGALRM) to let the operating system interrupt program
+ execution. This has the following limitations: Python signal
+ handlers only apply to the main thread, so you cannot use this
+ from other threads. You must not use this in a program that
+ uses SIGALRM itself (this includes certain profilers)
+ """
+ chan.sendall('{c}\n'.format(c=cmd))
+ buf = ''
+ while not buf.endswith(prompt):
+ try:
+ chunk = chan.recv(10 * 1024 * 1024)
+ if not chunk:
+ break
+ buf += chunk
+ if chan.exit_status_ready():
+ self.log.error('Channel exit status ready')
+ break
+ except socket.timeout:
+ message = ("Socket timeout during execution of command: "
+ "%(cmd)s\nBuffer content:\n%(buf)s" % {"cmd": cmd,
+ "buf": buf})
+ raise exceptions.SSHTimeout(error_msg=message)
+ tmp = buf.replace(cmd.replace('\n', ''), '')
+ for item in prompt:
+ tmp.replace(item, '')
+ return tmp
+
+ @staticmethod
+ def interactive_terminal_close(chan):
+ """Close interactive terminal SSH channel.
+
+ :param: chan: SSH channel to be closed.
+ """
+ chan.close()
+
class AutoConnectSSH(SSH):
+ @classmethod
+ def get_arg_key_map(cls):
+ arg_key_map = super(AutoConnectSSH, cls).get_arg_key_map()
+ arg_key_map['wait'] = ('wait', True)
+ return arg_key_map
+
# always wait or we will get OpenStack SSH errors
def __init__(self, user, host, port=None, pkey=None,
key_filename=None, password=None, name=None, wait=True):
super(AutoConnectSSH, self).__init__(user, host, port, pkey, key_filename, password, name)
- self._wait = wait
+ if wait and wait is not True:
+ self.wait_timeout = int(wait)
def _make_dict(self):
data = super(AutoConnectSSH, self)._make_dict()
data.update({
- 'wait': self._wait
+ 'wait': self.wait_timeout
})
return data
def _connect(self):
if not self.is_connected:
- self._get_client()
- if self._wait:
- self.wait()
+ interval = 1
+ timeout = self.wait_timeout
+
+ end_time = time.time() + timeout
+ while True:
+ try:
+ return self._get_client()
+ except (socket.error, exceptions.SSHError) as e:
+ self.log.debug("Ssh is still unavailable: %r", e)
+ time.sleep(interval)
+ if time.time() > end_time:
+ raise exceptions.SSHTimeout(
+ error_msg='Timeout waiting for "%s"' % self.host)
def drop_connection(self):
""" Don't close anything, just force creation of a new client """
self._client = False
- def execute(self, cmd, stdin=None, timeout=3600):
+ def execute(self, cmd, stdin=None, timeout=3600, raise_on_error=False):
self._connect()
- return super(AutoConnectSSH, self).execute(cmd, stdin, timeout)
+ return super(AutoConnectSSH, self).execute(cmd, stdin, timeout,
+ raise_on_error)
def run(self, cmd, stdin=None, stdout=None, stderr=None,
raise_on_error=True, timeout=3600,