# Copyright 2013: Mirantis Inc. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. """High level ssh library. Usage examples: Execute command and get output: ssh = sshclient.SSH('root', 'example.com', port=33) status, stdout, stderr = ssh.execute('ps ax') if status: raise Exception('Command failed with non-zero status.') print stdout.splitlines() Execute command with huge output: class PseudoFile(object): def write(chunk): if 'error' in chunk: email_admin(chunk) ssh = sshclient.SSH('root', 'example.com') ssh.run('tail -f /var/log/syslog', stdout=PseudoFile(), timeout=False) Execute local script on remote side: ssh = sshclient.SSH('user', 'example.com') status, out, err = ssh.execute('/bin/sh -s arg1 arg2', stdin=open('~/myscript.sh', 'r')) Upload file: ssh = sshclient.SSH('user', 'example.com') ssh.run('cat > ~/upload/file.gz', stdin=open('/store/file.gz', 'rb')) Eventlet: eventlet.monkey_patch(select=True, time=True) or eventlet.monkey_patch() or sshclient = eventlet.import_patched("opentstack.common.sshclient") """ import re import select import shlex import socket import StringIO import subprocess import sys import threading import time from log import LOG import paramiko # from rally.openstack.common.gettextutils import _ class ConnectionError(Exception): pass class Connection(object): ''' A base connection class. Not intended to be constructed. ''' def __init__(self): self.distro_id = None self.distro_id_like = None self.distro_version = None self.__get_distro() def close(self): pass def execute(self, cmd, stdin=None, timeout=3600): pass def __extract_property(self, name, input_str): expr = name + r'="?([\w\.]*)"?' match = re.search(expr, input_str) if match: return match.group(1) return 'Unknown' # Get the linux distro def __get_distro(self): '''cat /etc/*-release | grep ID Ubuntu: DISTRIB_ID=Ubuntu ID=ubuntu ID_LIKE=debian VERSION_ID="14.04" RHEL: ID="rhel" ID_LIKE="fedora" VERSION_ID="7.0" ''' distro_cmd = "grep ID /etc/*-release" (status, distro_out, _) = self.execute(distro_cmd) if status: distro_out = '' self.distro_id = self.__extract_property('ID', distro_out) self.distro_id_like = self.__extract_property('ID_LIKE', distro_out) self.distro_version = self.__extract_property('VERSION_ID', distro_out) def pidof(self, proc_name): ''' Return a list containing the pids of all processes of a given name the list is empty if there is no pid ''' # the path update is necessary for RHEL cmd = "PATH=$PATH:/usr/sbin pidof " + proc_name (status, cmd_output, _) = self.execute(cmd) if status: return [] cmd_output = cmd_output.strip() result = cmd_output.split() return result # kill pids in the given list of pids def kill_proc(self, pid_list): cmd = "kill -9 " + ' '.join(pid_list) self.execute(cmd) # check stats for a given path def stat(self, path): (status, cmd_output, _) = self.execute('stat ' + path) if status: return None return cmd_output def ping_check(self, target_ip, ping_count=2, pass_threshold=80): '''helper function to ping from one host to an IP address, for a given count and pass_threshold; Steps: ssh to the host and then ping to the target IP then match the output and verify that the loss% is less than the pass_threshold% Return 1 if the criteria passes Return 0, if it fails ''' cmd = "ping -c " + str(ping_count) + " " + str(target_ip) (_, cmd_output, _) = self.execute(cmd) match = re.search(r'(\d*)% packet loss', cmd_output) pkt_loss = match.group(1) if int(pkt_loss) < int(pass_threshold): return 1 else: LOG.error('Ping to %s failed: %s', target_ip, cmd_output) return 0 def read_remote_file(self, from_path): ''' Read a remote file and save it to a buffer. ''' cmd = "cat " + from_path (status, cmd_output, _) = self.execute(cmd) if status: return None return cmd_output def get_host_os_version(self): ''' Identify the host distribution/relase. ''' os_release_file = "/etc/os-release" sys_release_file = "/etc/system-release" name = "" version = "" if self.stat(os_release_file): data = self.read_remote_file(os_release_file) if data is None: LOG.error("Failed to read file %s", os_release_file) return None for line in data.splitlines(): mobj = re.match(r'NAME=(.*)', line) if mobj: name = mobj.group(1).strip("\"") mobj = re.match(r'VERSION_ID=(.*)', line) if mobj: version = mobj.group(1).strip("\"") os_name = name + " " + version return os_name if self.stat(sys_release_file): data = self.read_remote_file(sys_release_file) if data is None: LOG.error("Failed to read file %s", sys_release_file) return None for line in data.splitlines(): mobj = re.match(r'Red Hat.*', line) if mobj: return mobj.group(0) return None def check_rpm_package_installed(self, rpm_pkg): ''' Given a host and a package name, check if it is installed on the system. ''' check_pkg_cmd = "rpm -qa | grep " + rpm_pkg (status, cmd_output, _) = self.execute(check_pkg_cmd) if status: return None pkg_pattern = ".*" + rpm_pkg + ".*" rpm_pattern = re.compile(pkg_pattern, re.IGNORECASE) for line in cmd_output.splitlines(): mobj = rpm_pattern.match(line) if mobj: return mobj.group(0) LOG.info("%s pkg installed ", rpm_pkg) return None def get_openstack_release(self, ver_str): ''' Get the release series name from the package version Refer to here for release tables: https://wiki.openstack.org/wiki/Releases ''' ver_table = {"2015.1": "Kilo", "2014.2": "Juno", "2014.1": "Icehouse", "2013.2": "Havana", "2013.1": "Grizzly", "2012.2": "Folsom", "2012.1": "Essex", "2011.3": "Diablo", "2011.2": "Cactus", "2011.1": "Bexar", "2010.1": "Austin"} ver_prefix = re.search(r"20\d\d\.\d", ver_str).group(0) if ver_prefix in ver_table: return ver_table[ver_prefix] else: return "Unknown" def check_openstack_version(self): ''' Identify the openstack version running on the controller. ''' nova_cmd = "nova-manage --version" (status, _, err_output) = self.execute(nova_cmd) if status: return "Unknown" ver_str = err_output.strip() release_str = self.get_openstack_release(err_output) return release_str + " (" + ver_str + ")" def get_cpu_info(self): ''' Get the CPU info of the controller. Note: Here we are assuming the controller node has the exact hardware as the compute nodes. ''' cmd = 'cat /proc/cpuinfo | grep -m1 "model name"' (status, std_output, _) = self.execute(cmd) if status: return "Unknown" model_name = re.search(r":\s(.*)", std_output).group(1) cmd = 'cat /proc/cpuinfo | grep "model name" | wc -l' (status, std_output, _) = self.execute(cmd) if status: return "Unknown" cores = std_output.strip() return (cores + " * " + model_name) def get_nic_name(self, agent_type, encap, internal_iface_dict): ''' Get the NIC info of the controller. Note: Here we are assuming the controller node has the exact hardware as the compute nodes. ''' # The internal_ifac_dict is a dictionary contains the mapping between # hostname and the internal interface name like below: # {u'hh23-4': u'eth1', u'hh23-5': u'eth1', u'hh23-6': u'eth1'} cmd = "hostname" (status, std_output, _) = self.execute(cmd) if status: return "Unknown" hostname = std_output.strip() if hostname in internal_iface_dict: iface = internal_iface_dict[hostname] else: return "Unknown" # Figure out which interface is for internal traffic if 'Linux bridge' in agent_type: ifname = iface elif 'Open vSwitch' in agent_type: if encap == 'vlan': # [root@hh23-10 ~]# ovs-vsctl list-ports br-inst # eth1 # phy-br-inst cmd = 'ovs-vsctl list-ports ' + \ iface + ' | grep -E "^[^phy].*"' (status, std_output, _) = self.execute(cmd) if status: return "Unknown" ifname = std_output.strip() elif encap == 'vxlan' or encap == 'gre': # This is complicated. We need to first get the local IP address on # br-tun, then do a reverse lookup to get the physical interface. # # [root@hh23-4 ~]# ip addr show to "23.23.2.14" # 3: eth1: mtu 1500 qdisc mq state UP qlen 1000 # inet 23.23.2.14/24 brd 23.23.2.255 scope global eth1 # valid_lft forever preferred_lft forever cmd = "ip addr show to " + iface + " | awk -F: '{print $2}'" (status, std_output, _) = self.execute(cmd) if status: return "Unknown" ifname = std_output.strip() else: return "Unknown" cmd = 'ethtool -i ' + ifname + ' | grep bus-info' (status, std_output, _) = self.execute(cmd) if status: return "Unknown" bus_info = re.search(r":\s(.*)", std_output).group(1) cmd = 'lspci -s ' + bus_info (status, std_output, _) = self.execute(cmd) if status: return "Unknown" nic_name = re.search( r"Ethernet controller:\s(.*)", std_output).group(1) return (nic_name) def get_l2agent_version(self, agent_type): ''' Get the L2 agent version of the controller. Note: Here we are assuming the controller node has the exact hardware as the compute nodes. ''' if 'Linux bridge' in agent_type: cmd = "brctl --version | awk -F',' '{print $2}'" ver_string = "Linux Bridge " elif 'Open vSwitch' in agent_type: cmd = "ovs-vsctl --version | awk -F')' '{print $2}'" ver_string = "OVS " else: return "Unknown" (status, std_output, _) = self.execute(cmd) if status: return "Unknown" return ver_string + std_output.strip() class SSHError(Exception): pass class SSHTimeout(SSHError): pass # Check IPv4 address syntax - not completely fool proof but will catch # some invalid formats def is_ipv4(address): try: socket.inet_aton(address) except socket.error: return False return True class SSHAccess(object): ''' A class to contain all the information needed to access a host (native or virtual) using SSH ''' def __init__(self, arg_value=None): ''' decode user@host[:pwd] 'hugo@1.1.1.1:secret' -> ('hugo', '1.1.1.1', 'secret', None) 'huggy@2.2.2.2' -> ('huggy', '2.2.2.2', None, None) None ->(None, None, None, None) Examples of fatal errors (will call exit): 'hutch@q.1.1.1' (invalid IP) '@3.3.3.3' (missing username) 'hiro@' or 'buggy' (missing host IP) The error field will be None in case of success or will contain a string describing the error ''' self.username = None self.host = None self.password = None # name of the file that contains the private key self.private_key_file = None # this is the private key itself (a long string starting with # -----BEGIN RSA PRIVATE KEY----- # used when the private key is not saved in any file self.private_key = None self.public_key_file = None self.port = 22 self.error = None if not arg_value: return match = re.search(r'^([^@]+)@([0-9\.]+):?(.*)$', arg_value) if not match: self.error = 'Invalid argument: ' + arg_value return if not is_ipv4(match.group(2)): self.error = 'Invalid IPv4 address ' + match.group(2) return (self.username, self.host, self.password) = match.groups() def copy_from(self, ssh_access): self.username = ssh_access.username self.host = ssh_access.host self.port = ssh_access.port self.password = ssh_access.password self.private_key = ssh_access.private_key self.public_key_file = ssh_access.public_key_file self.private_key_file = ssh_access.private_key_file class SSH(Connection): """Represent ssh connection.""" def __init__(self, ssh_access, connect_timeout=60, connect_retry_count=30, connect_retry_wait_sec=2): """Initialize SSH client. :param user: ssh username :param host: hostname or ip address of remote ssh server :param port: remote ssh port :param pkey: RSA or DSS private key string or file object :param key_filename: private key filename :param password: password :param connect_timeout: timeout when connecting ssh :param connect_retry_count: how many times to retry connecting :param connect_retry_wait_sec: seconds to wait between retries """ self.ssh_access = ssh_access if ssh_access.private_key: self.pkey = self._get_pkey(ssh_access.private_key) else: self.pkey = None self._client = False self.connect_timeout = connect_timeout self.connect_retry_count = connect_retry_count self.connect_retry_wait_sec = connect_retry_wait_sec super(SSH, self).__init__() def _get_pkey(self, key): '''Get the binary form of the private key from the text form ''' if isinstance(key, basestring): key = StringIO.StringIO(key) errors = [] for key_class in (paramiko.rsakey.RSAKey, paramiko.dsskey.DSSKey): try: return key_class.from_private_key(key) except paramiko.SSHException as exc: errors.append(exc) raise SSHError('Invalid pkey: %s' % (errors)) def _is_active(self): if self._client: try: transport = self._client.get_transport() session = transport.open_session() session.close() return True except Exception: return False else: return False def _get_client(self, force=False): if not force and self._is_active(): return self._client if self._client: LOG.info('Re-establishing ssh connection with %s' % (self.ssh_access.host)) self._client.close() self._client = paramiko.SSHClient() self._client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) for _ in range(self.connect_retry_count): try: self._client.connect(self.ssh_access.host, username=self.ssh_access.username, port=self.ssh_access.port, pkey=self.pkey, key_filename=self.ssh_access.private_key_file, password=self.ssh_access.password, timeout=self.connect_timeout) self._client.get_transport().set_keepalive(5) return self._client except (paramiko.AuthenticationException, paramiko.BadHostKeyException, paramiko.SSHException, socket.error, Exception): time.sleep(self.connect_retry_wait_sec) self._client = None msg = '[%s] SSH Connection failed after %s attempts' % (self.ssh_access.host, self.connect_retry_count) raise SSHError(msg) def _get_session(self): client = self._get_client() for _ in range(self.connect_retry_count): try: transport = client.get_transport() session = transport.open_session() return session except Exception: client = self._get_client(force=True) return None def close(self): super(SSH, self).close() if self._client: self._client.close() self._client = False def run(self, cmd, stdin=None, stdout=None, stderr=None, raise_on_error=True, timeout=3600, sudo=False): """Execute specified command on the server. :param cmd: Command to be executed. :param stdin: Open file or string to pass to stdin. :param stdout: Open file to connect to stdout. :param stderr: Open file to connect to stderr. :param raise_on_error: If False then exit code will be return. If True then exception will be raized if non-zero code. :param timeout: Timeout in seconds for command execution. Default 1 hour. No timeout if set to 0. :param sudo: Executes command as sudo with default password """ if isinstance(stdin, basestring): stdin = StringIO.StringIO(stdin) return self._run(cmd, stdin=stdin, stdout=stdout, stderr=stderr, raise_on_error=raise_on_error, timeout=timeout, sudo=sudo) def _run(self, cmd, stdin=None, stdout=None, stderr=None, raise_on_error=True, timeout=3600, sudo=False): session = self._get_session() if session is None: raise SSHError('Unable to open session to ssh connection') if sudo: cmd = "echo " + self.ssh_access.password + " | sudo -S -p '' " + cmd session.get_pty() session.exec_command(cmd) start_time = time.time() data_to_send = '' stderr_data = None # If we have data to be sent to stdin then `select' should also # check for stdin availability. if stdin and not stdin.closed: writes = [session] else: writes = [] while True: # Block until data can be read/write. select.select([session], writes, [session], 1) if session.recv_ready(): data = session.recv(4096) if stdout is not None: stdout.write(data) continue if session.recv_stderr_ready(): stderr_data = session.recv_stderr(4096) if stderr is not None: stderr.write(stderr_data) continue if session.send_ready(): if stdin is not None and not stdin.closed: if not data_to_send: data_to_send = stdin.read(4096) if not data_to_send: stdin.close() session.shutdown_write() writes = [] continue sent_bytes = session.send(data_to_send) data_to_send = data_to_send[sent_bytes:] if session.exit_status_ready(): break if timeout and (time.time() - timeout) > start_time: args = {'cmd': cmd, 'host': self.ssh_access.host} raise SSHTimeout(('Timeout executing command ' '"%(cmd)s" on host %(host)s') % args) # if e: # raise SSHError('Socket error.') exit_status = session.recv_exit_status() if 0 != exit_status and raise_on_error: fmt = ('Command "%(cmd)s" failed with exit_status %(status)d.') details = fmt % {'cmd': cmd, 'status': exit_status} if stderr_data: details += (' Last stderr data: "%s".') % stderr_data raise SSHError(details) return exit_status def execute(self, cmd, stdin=None, timeout=3600, sudo=False): """Execute the specified command on the server. :param cmd: Command to be executed. :param stdin: Open file to be sent on process stdin. :param timeout: Timeout for execution of the command. Return tuple (exit_status, stdout, stderr) """ stdout = StringIO.StringIO() stderr = StringIO.StringIO() exit_status = self.run(cmd, stderr=stderr, stdout=stdout, stdin=stdin, timeout=timeout, raise_on_error=False, sudo=sudo) stdout.seek(0) stderr.seek(0) return (exit_status, stdout.read(), stderr.read()) def wait(self, timeout=120, interval=1): """Wait for the host will be available via ssh.""" start_time = time.time() while True: try: return self.execute('uname') except (socket.error, SSHError): time.sleep(interval) if time.time() > (start_time + timeout): raise SSHTimeout( ('Timeout waiting for "%s"') % self.ssh_access.host) class SubprocessTimeout(Exception): pass class Subprocess(Connection): """Represent subprocess connection.""" def execute(self, cmd, stdin=None, timeout=3600): process = subprocess.Popen(shlex.split(cmd), stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=True) timer = threading.Timer(timeout, process.kill) stdout, stderr = process.communicate(input=stdin) status = process.wait() if timer.is_alive(): timer.cancel() raise SubprocessTimeout('Timeout executing command "%(cmd)s"') return (status, stdout, stderr) ################################################## # Only invoke the module directly for test purposes. Should be # invoked from pns script. ################################################## def main(): # As argument pass the SSH access string, e.g. "localadmin@1.1.1.1:secret" test_ssh = SSH(SSHAccess(sys.argv[1])) print 'ID=' + test_ssh.distro_id print 'ID_LIKE=' + test_ssh.distro_id_like print 'VERSION_ID=' + test_ssh.distro_version # ssh.wait() # print ssh.pidof('bash') # print ssh.stat('/tmp') print test_ssh.check_openstack_version() print test_ssh.get_cpu_info() print test_ssh.get_l2agent_version("Open vSwitch agent") if __name__ == "__main__": main()