aboutsummaryrefslogtreecommitdiffstats
path: root/app/utils/ssh_connection.py
diff options
context:
space:
mode:
Diffstat (limited to 'app/utils/ssh_connection.py')
-rw-r--r--app/utils/ssh_connection.py217
1 files changed, 217 insertions, 0 deletions
diff --git a/app/utils/ssh_connection.py b/app/utils/ssh_connection.py
new file mode 100644
index 0000000..0fa197a
--- /dev/null
+++ b/app/utils/ssh_connection.py
@@ -0,0 +1,217 @@
+###############################################################################
+# Copyright (c) 2017 Koren Lev (Cisco Systems), Yaron Yogev (Cisco Systems) #
+# and others #
+# #
+# All rights reserved. This program and the accompanying materials #
+# are made available under the terms of the Apache License, Version 2.0 #
+# which accompanies this distribution, and is available at #
+# http://www.apache.org/licenses/LICENSE-2.0 #
+###############################################################################
+import os
+
+import paramiko
+
+from utils.binary_converter import BinaryConverter
+
+
+class SshConnection(BinaryConverter):
+ config = None
+ ssh = None
+ connections = {}
+ cli_connections = {}
+ sftp_connections = {}
+
+ max_call_count_per_con = 100
+ timeout = 15 # timeout for exec in seconds
+
+ DEFAULT_PORT = 22
+
+ def __init__(self, _host: str, _user: str, _pwd: str=None, _key: str = None,
+ _port: int = None, _call_count_limit: int=None,
+ for_sftp: bool = False):
+ super().__init__()
+ self.host = _host
+ self.ssh = None
+ self.ftp = None
+ self.for_sftp = for_sftp
+ self.key = _key
+ self.port = _port
+ self.user = _user
+ self.pwd = _pwd
+ self.check_definitions()
+ self.fetched_host_details = False
+ self.call_count = 0
+ self.call_count_limit = 0 if for_sftp \
+ else (SshConnection.max_call_count_per_con
+ if _call_count_limit is None else _call_count_limit)
+ if for_sftp:
+ self.sftp_connections[_host] = self
+ else:
+ self.cli_connections[_host] = self
+
+ def check_definitions(self):
+ if not self.host:
+ raise ValueError('Missing definition of host for CLI access')
+ if not self.user:
+ raise ValueError('Missing definition of user ' +
+ 'for CLI access to host {}'.format(self.host))
+ if self.key and not os.path.exists(self.key):
+ raise ValueError('Key file not found: ' + self.key)
+ if not self.key and not self.pwd:
+ raise ValueError('Must specify key or password ' +
+ 'for CLI access to host {}'.format(self.host))
+
+ @staticmethod
+ def get_ssh(host, for_sftp=False):
+ if for_sftp:
+ return SshConnection.cli_connections.get(host)
+ return SshConnection.sftp_connections.get(host)
+
+ @staticmethod
+ def get_connection(host, for_sftp=False):
+ key = ('sftp-' if for_sftp else '') + host
+ return SshConnection.connections.get(key)
+
+ def disconnect(self):
+ if self.ssh:
+ self.ssh.close()
+
+ @staticmethod
+ def disconnect_all():
+ for ssh in SshConnection.cli_connections.values():
+ ssh.disconnect()
+ SshConnection.cli_connections = {}
+ for ssh in SshConnection.sftp_connections.values():
+ ssh.disconnect()
+ SshConnection.sftp_connections = {}
+
+ def get_host(self):
+ return self.host
+
+ def get_user(self):
+ return self.user
+
+ def set_call_limit(self, _limit: int):
+ self.call_count_limit = _limit
+
+ def connect(self, reconnect=False) -> bool:
+ connection = self.get_connection(self.host, self.for_sftp)
+ if connection:
+ self.ssh = connection
+ if reconnect:
+ self.log.info("SshConnection: " +
+ "****** forcing reconnect: %s ******",
+ self.host)
+ elif self.call_count >= self.call_count_limit > 0:
+ self.log.info("SshConnection: ****** reconnecting: %s, " +
+ "due to call count: %s ******",
+ self.host, self.call_count)
+ else:
+ return True
+ connection.close()
+ self.ssh = None
+ self.ssh = paramiko.SSHClient()
+ connection_key = ('sftp-' if self.for_sftp else '') + self.host
+ SshConnection.connections[connection_key] = self.ssh
+ self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ if self.key:
+ k = paramiko.RSAKey.from_private_key_file(self.key)
+ self.ssh.connect(hostname=self.host, username=self.user, pkey=k,
+ port=self.port if self.port is not None
+ else self.DEFAULT_PORT,
+ password=self.pwd, timeout=30)
+ else:
+ try:
+ port = self.port if self.port is not None else self.DEFAULT_PORT
+ self.ssh.connect(self.host,
+ username=self.user,
+ password=self.pwd,
+ port=port,
+ timeout=30)
+ except paramiko.ssh_exception.AuthenticationException:
+ self.log.error('Failed SSH connect to host {}, port={}'
+ .format(self.host, port))
+ self.ssh = None
+ self.call_count = 0
+ return self.ssh is not None
+
+ def exec(self, cmd):
+ if not self.connect():
+ return ''
+ self.call_count += 1
+ self.log.debug("call count: %s, running call:\n%s\n",
+ str(self.call_count), cmd)
+ stdin, stdout, stderr = self.ssh.exec_command(cmd, timeout=self.timeout)
+ stdin.close()
+ err = self.binary2str(stderr.read())
+ if err:
+ # ignore messages about loading plugin
+ err_lines = [l for l in err.splitlines()
+ if 'Loaded plugin: ' not in l]
+ if err_lines:
+ self.log.error("CLI access: \n" +
+ "Host: {}\nCommand: {}\nError: {}\n".
+ format(self.host, cmd, err))
+ stderr.close()
+ stdout.close()
+ return ""
+ ret = self.binary2str(stdout.read())
+ stderr.close()
+ stdout.close()
+ return ret
+
+ def copy_file(self, local_path, remote_path, mode=None):
+ if not self.connect():
+ return
+ if not self.ftp:
+ self.ftp = self.ssh.open_sftp()
+ try:
+ self.ftp.put(local_path, remote_path)
+ except IOError as e:
+ self.log.error('SFTP copy_file failed to copy file: ' +
+ 'local: ' + local_path +
+ ', remote host: ' + self.host +
+ ', error: ' + str(e))
+ return str(e)
+ try:
+ remote_file = self.ftp.file(remote_path, 'a+')
+ except IOError as e:
+ self.log.error('SFTP copy_file failed to open file after put(): ' +
+ 'local: ' + local_path +
+ ', remote host: ' + self.host +
+ ', error: ' + str(e))
+ return str(e)
+ try:
+ if mode:
+ remote_file.chmod(mode)
+ except IOError as e:
+ self.log.error('SFTP copy_file failed to chmod file: ' +
+ 'local: ' + local_path +
+ ', remote host: ' + self.host +
+ ', port: ' + self.port +
+ ', error: ' + str(e))
+ return str(e)
+ self.log.info('SFTP copy_file success: '
+ 'host={},port={},{} -> {}'.format(
+ str(self.host), str(self.port), str(local_path), str(remote_path)))
+ return ''
+
+ def copy_file_from_remote(self, remote_path, local_path):
+ if not self.connect():
+ return
+ if not self.ftp:
+ self.ftp = self.ssh.open_sftp()
+ try:
+ self.ftp.get(remote_path, local_path)
+ except IOError as e:
+ self.log.error('SFTP copy_file_from_remote failed to copy file: '
+ 'remote host: {}, '
+ 'remote_path: {}, local: {}, error: {}'
+ .format(self.host, remote_path, local_path, str(e)))
+ return str(e)
+ self.log.info('SFTP copy_file_from_remote success: host={},{} -> {}'.
+ format(self.host, remote_path, local_path))
+ return ''
+
+ def is_gateway_host(self, host):
+ return True