###############################################################################
# Copyright (c) 2017 Koren Lev (Cisco Systems), Yaron Yogev (Cisco Systems)   #
# and others                                                                  #
#                                                                             #
# All rights reserved. This program and the accompanying materials            #
# are made available under the terms of the Apache License, Version 2.0       #
# which accompanies this distribution, and is available at                    #
# http://www.apache.org/licenses/LICENSE-2.0                                  #
###############################################################################
import os

import paramiko

from utils.binary_converter import BinaryConverter


class SshConnection(BinaryConverter):
    connections = {}

    max_call_count_per_con = 100
    timeout = 15  # timeout for exec in seconds

    DEFAULT_PORT = 22

    def __init__(self, _host: str, _user: str, _pwd: str=None, _key: str = None,
                 _port: int = None,  _call_count_limit: int=None,
                 for_sftp: bool = False):
        super().__init__()
        self.host = _host
        self.ssh_client = None
        self.ftp = None
        self.for_sftp = for_sftp
        self.key = _key
        self.port = _port
        self.user = _user
        self.pwd = _pwd
        self.check_definitions()
        self.fetched_host_details = False
        self.call_count = 0
        self.call_count_limit = 0 if for_sftp \
            else (SshConnection.max_call_count_per_con
                  if _call_count_limit is None else _call_count_limit)
        self.connections[self.get_connection_key(_host, for_sftp)] = self

    def check_definitions(self):
        if not self.host:
            raise ValueError('Missing definition of host for CLI access')
        if not self.user:
            raise ValueError('Missing definition of user ' +
                             'for CLI access to host {}'.format(self.host))
        if self.key and not os.path.exists(self.key):
            raise ValueError('Key file not found: ' + self.key)
        if not self.key and not self.pwd:
            raise ValueError('Must specify key or password ' +
                             'for CLI access to host {}'.format(self.host))

    @staticmethod
    def get_ssh(host, _for_sftp=False):
        return SshConnection.get_connection(host, for_sftp=_for_sftp)

    @staticmethod
    def get_connection_key(host, for_sftp=False):
        key = ('sftp-' if for_sftp else '') + host
        return key

    @staticmethod
    def get_connection(host, for_sftp=False):
        key = SshConnection.get_connection_key(host, for_sftp)
        return SshConnection.connections.get(key)

    def disconnect(self):
        if self.ssh_client:
            self.ssh_client.close()

    @staticmethod
    def disconnect_all():
        for ssh in SshConnection.connections.values():
            ssh.disconnect()
        SshConnection.connections = {}

    def get_host(self):
        return self.host

    def get_user(self):
        return self.user

    def set_call_limit(self, _limit: int):
        self.call_count_limit = _limit

    def connect(self, reconnect=False) -> bool:
        connection = self.get_connection(self.host, self.for_sftp)
        if connection and connection.ssh_client:
            self.ssh_client = connection.ssh_client
            if reconnect:
                self.log.info("SshConnection: " +
                              "****** forcing reconnect: %s ******",
                              self.host)
            elif self.call_count >= self.call_count_limit > 0:
                self.log.info("SshConnection: ****** reconnecting: %s, " +
                              "due to call count: %s ******",
                              self.host, self.call_count)
            else:
                return True
            connection.close()
            self.ssh_client = None
        self.ssh_client = paramiko.SSHClient()
        connection_key = SshConnection.get_connection_key(self.host,
                                                          self.for_sftp)
        SshConnection.connections[connection_key] = self
        self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        if self.key:
            k = paramiko.RSAKey.from_private_key_file(self.key)
            self.ssh_client.connect(hostname=self.host,
                                    username=self.user,
                                    pkey=k,
                                    port=self.port if self.port is not None
                                    else self.DEFAULT_PORT,
                                    password=self.pwd, timeout=30)
        else:
            try:
                port = self.port if self.port is not None else self.DEFAULT_PORT
                self.ssh_client.connect(self.host,
                                        username=self.user,
                                        password=self.pwd,
                                        port=port,
                                        timeout=30)
            except paramiko.ssh_exception.AuthenticationException:
                self.log.error('Failed SSH connect to host {}, port={}'
                               .format(self.host, port))
                self.ssh_client = None
        self.call_count = 0
        return self.ssh_client is not None

    def exec(self, cmd):
        if not self.connect():
            return ''
        self.call_count += 1
        self.log.debug("call count: %s, running call:\n%s\n",
                       str(self.call_count), cmd)
        stdin, stdout, stderr = \
            self.ssh_client.exec_command(cmd, timeout=self.timeout)
        stdin.close()
        err = self.binary2str(stderr.read())
        if err:
            # ignore messages about loading plugin
            err_lines = [l for l in err.splitlines()
                         if 'Loaded plugin: ' not in l]
            if err_lines:
                self.log.error("CLI access: \n" +
                               "Host: {}\nCommand: {}\nError: {}\n".
                               format(self.host, cmd, err))
                stderr.close()
                stdout.close()
                return ""
        ret = self.binary2str(stdout.read())
        stderr.close()
        stdout.close()
        return ret

    def copy_file(self, local_path, remote_path, mode=None):
        if not self.connect():
            return
        if not self.ftp:
            self.ftp = self.ssh_client.open_sftp()
        try:
            self.ftp.put(local_path, remote_path)
        except IOError as e:
            self.log.error('SFTP copy_file failed to copy file: ' +
                           'local: ' + local_path +
                           ', remote host: ' + self.host +
                           ', error: ' + str(e))
            return str(e)
        try:
            remote_file = self.ftp.file(remote_path, 'a+')
        except IOError as e:
            self.log.error('SFTP copy_file failed to open file after put(): ' +
                           'local: ' + local_path +
                           ', remote host: ' + self.host +
                           ', error: ' + str(e))
            return str(e)
        try:
            if mode:
                remote_file.chmod(mode)
        except IOError as e:
            self.log.error('SFTP copy_file failed to chmod file: ' +
                           'local: ' + local_path +
                           ', remote host: ' + self.host +
                           ', port: ' + self.port +
                           ', error: ' + str(e))
            return str(e)
        self.log.info('SFTP copy_file success: '
                      'host={},port={},{} -> {}'.format(
            str(self.host), str(self.port), str(local_path), str(remote_path)))
        return ''

    def copy_file_from_remote(self, remote_path, local_path):
        if not self.connect():
            return
        if not self.ftp:
            self.ftp = self.ssh_client.open_sftp()
        try:
            self.ftp.get(remote_path, local_path)
        except IOError as e:
            self.log.error('SFTP copy_file_from_remote failed to copy file: '
                           'remote host: {}, '
                           'remote_path: {}, local: {}, error: {}'
                           .format(self.host, remote_path, local_path, str(e)))
            return str(e)
        self.log.info('SFTP copy_file_from_remote success: host={},{} -> {}'.
                      format(self.host, remote_path, local_path))
        return ''

    def is_gateway_host(self, host):
        return True