diff options
Diffstat (limited to 'compass-tasks-base/utils/util.py')
-rw-r--r-- | compass-tasks-base/utils/util.py | 395 |
1 files changed, 395 insertions, 0 deletions
diff --git a/compass-tasks-base/utils/util.py b/compass-tasks-base/utils/util.py new file mode 100644 index 0000000..39978ca --- /dev/null +++ b/compass-tasks-base/utils/util.py @@ -0,0 +1,395 @@ +# Copyright 2014 Huawei Technologies Co. Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Module to provider util functions in all compass code + + .. moduleauthor:: Xiaodong Wang <xiaodongwang@huawei.com> +""" + +import crypt +import datetime +import logging +import os +import os.path +import re +import setting_wrapper as setting +import sys +import warnings + + +def deprecated(func): + """This is a decorator which can be used to mark functions as deprecated. + + It will result in a warning being emitted when the function is used. + """ + def new_func(*args, **kwargs): + warnings.warn( + "Call to deprecated function %s." % func.__name__, + category=DeprecationWarning + ) + return func(*args, **kwargs) + + new_func.__name__ = func.__name__ + new_func.__doc__ = func.__doc__ + new_func.__dict__.update(func.__dict__) + return new_func + + +def parse_datetime(date_time, exception_class=Exception): + """Parse datetime str to get datetime object. + + The date time format is %Y-%m-%d %H:%M:%S + """ + try: + return datetime.datetime.strptime( + date_time, '%Y-%m-%d %H:%M:%S' + ) + except Exception as error: + logging.exception(error) + raise exception_class( + 'date time %s format is invalid' % date_time + ) + + +def parse_datetime_range(date_time_range, exception_class=Exception): + """parse datetime range str to pair of datetime objects. + + The date time range format is %Y-%m-%d %H:%M:%S,%Y-%m-%d %H:%M:%S + """ + try: + start, end = date_time_range.split(',') + except Exception as error: + logging.exception(error) + raise exception_class( + 'there is no `,` in date time range %s' % date_time_range + ) + if start: + start_datetime = parse_datetime(start, exception_class) + else: + start_datetime = None + if end: + end_datetime = parse_datetime(end, exception_class) + else: + end_datetime = None + return start_datetime, end_datetime + + +def parse_request_arg_dict(arg, exception_class=Exception): + """parse string to dict. + + The str is formatted like a=b;c=d and parsed to + {'a': 'b', 'c': 'd'} + """ + arg_dict = {} + arg_pairs = arg.split(';') + for arg_pair in arg_pairs: + try: + arg_name, arg_value = arg_pair.split('=', 1) + except Exception as error: + logging.exception(error) + raise exception_class( + 'there is no `=` in %s' % arg_pair + ) + arg_dict[arg_name] = arg_value + return arg_dict + + +def format_datetime(date_time): + """Generate string from datetime object.""" + return date_time.strftime("%Y-%m-%d %H:%M:%S") + + +def merge_dict(lhs, rhs, override=True): + """Merge nested right dict into left nested dict recursively. + + :param lhs: dict to be merged into. + :type lhs: dict + :param rhs: dict to merge from. + :type rhs: dict + :param override: the value in rhs overide the value in left if True. + :type override: boolean + """ + if not isinstance(lhs, dict) or not isinstance(rhs, dict): + if override: + return rhs + else: + return lhs + + for key, value in rhs.items(): + if key not in lhs: + lhs[key] = rhs[key] + else: + lhs[key] = merge_dict(lhs[key], value, override) + + return lhs + + +def recursive_merge_dict(name, all_dicts, parents): + """Recursively merge parent dict into base dict.""" + parent_name = parents.get(name, None) + base_dict = all_dicts.get(name, {}) + if not parent_name: + return base_dict + merged = recursive_merge_dict(parent_name, all_dicts, parents) + return merge_dict(base_dict, merged, override=False) + + +def encrypt(value, crypt_method=None): + """Get encrypted value.""" + if not crypt_method: + if hasattr(crypt, 'METHOD_MD5'): + crypt_method = crypt.METHOD_MD5 + else: + # for python2.7, copy python2.6 METHOD_MD5 logic here. + from random import choice + import string + + _saltchars = string.ascii_letters + string.digits + './' + + def _mksalt(): + """generate salt.""" + salt = '$1$' + salt += ''.join(choice(_saltchars) for _ in range(8)) + return salt + + crypt_method = _mksalt() + + return crypt.crypt(value, crypt_method) + + +def parse_time_interval(time_interval_str): + """parse string of time interval to time interval. + + supported time interval unit: ['d', 'w', 'h', 'm', 's'] + Examples: + time_interval_str: '3d 2h' time interval to 3 days and 2 hours. + """ + if not time_interval_str: + return 0 + + time_interval_tuple = [ + time_interval_element + for time_interval_element in time_interval_str.split(' ') + if time_interval_element + ] + time_interval_dict = {} + time_interval_unit_mapping = { + 'd': 'days', + 'w': 'weeks', + 'h': 'hours', + 'm': 'minutes', + 's': 'seconds' + } + for time_interval_element in time_interval_tuple: + mat = re.match(r'^([+-]?\d+)(w|d|h|m|s).*', time_interval_element) + if not mat: + continue + + time_interval_value = int(mat.group(1)) + time_interval_unit = time_interval_unit_mapping[mat.group(2)] + time_interval_dict[time_interval_unit] = ( + time_interval_dict.get(time_interval_unit, 0) + time_interval_value + ) + + time_interval = datetime.timedelta(**time_interval_dict) + if sys.version_info[0:2] > (2, 6): + return time_interval.total_seconds() + else: + return ( + time_interval.microseconds + ( + time_interval.seconds + time_interval.days * 24 * 3600 + ) * 1e6 + ) / 1e6 + + +def get_plugins_config_files(name, suffix=".conf"): + """walk through each of plugin to find all the config files in the""" + """name directory""" + + plugins_path = setting.PLUGINS_DIR + files = [] + if os.path.exists(plugins_path): + for plugin in os.listdir(plugins_path): + plugin_path = os.path.join(plugins_path, plugin) + plugin_config = os.path.join(plugin_path, name) + if os.path.exists(plugin_config): + for component in os.listdir(plugin_config): + if not component.endswith(suffix): + continue + files.append(os.path.join(plugin_config, component)) + return files + + +def load_configs( + config_dir, config_name_suffix='.conf', + env_globals={}, env_locals={} +): + """Load configurations from config dir.""" + """The config file could be in the config_dir or in plugins config_dir""" + """The plugins config_dir is formed as, for example /etc/compass/adapter""" + """Then the plugins config_dir is /etc/compass/plugins/xxx/adapter""" + + # TODO(Carl) instead of using config_dir, it should use a name such as + # adapter etc, however, doing it requires a lot client sites changes, + # will do it later. + + configs = [] + config_files = [] + config_dir = str(config_dir) + + """search for config_dir""" + if os.path.exists(config_dir): + for component in os.listdir(config_dir): + if not component.endswith(config_name_suffix): + continue + config_files.append(os.path.join(config_dir, component)) + + """search for plugins config_dir""" + index = config_dir.rfind("/") + + config_files.extend(get_plugins_config_files(config_dir[index + 1:], + config_name_suffix)) + + if not config_files: + logging.error('path %s and plugins does not exist', config_dir) + for path in config_files: + logging.debug('load config from %s', path) + config_globals = {} + config_globals.update(env_globals) + config_locals = {} + config_locals.update(env_locals) + try: + execfile(path, config_globals, config_locals) + except Exception as error: + logging.exception(error) + raise error + configs.append(config_locals) + return configs + + +def pretty_print(*contents): + """pretty print contents.""" + if len(contents) == 0: + print "" + else: + print "\n".join(content for content in contents) + + +def get_switch_machines_from_file(filename): + """get switch machines from file.""" + switches = [] + switch_machines = {} + with open(filename) as switch_file: + for line in switch_file: + line = line.strip() + if not line: + # ignore empty line + continue + + if line.startswith('#'): + # ignore comments + continue + + columns = [column for column in line.split(',')] + if not columns: + # ignore empty line + continue + + if columns[0] == 'switch': + (switch_ip, switch_vendor, switch_version, + switch_community, switch_state) = columns[1:] + switches.append({ + 'ip': switch_ip, + 'vendor': switch_vendor, + 'credentials': { + 'version': switch_version, + 'community': switch_community, + }, + 'state': switch_state, + }) + elif columns[0] == 'machine': + switch_ip, switch_port, mac = columns[1:] + switch_machines.setdefault(switch_ip, []).append({ + 'mac': mac, + 'port': switch_port, + }) + + return (switches, switch_machines) + + +def execute_cli_by_ssh(cmd, host, username, password=None, + keyfile='/root/.ssh/id_rsa', nowait=False): + """SSH to execute script on remote machine + + :param host: ip of the remote machine + :param username: username to access the remote machine + :param password: password to access the remote machine + :param cmd: command to execute + + """ + if not cmd: + logging.error("No command found!") + raise Exception('No command found!') + + if nowait: + cmd = "nohup %s >/dev/null 2>&1 &" % cmd + + stdin = None + stdout = None + stderr = None + try: + import paramiko + from paramiko import ssh_exception + + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + if password: + client.connect(host, username=username, password=password) + else: + client.load_system_host_keys() + client.connect( + host, username=username, + key_filename=keyfile, look_for_keys=True + ) + stdin, stdout, stderr = client.exec_command(cmd) + result = stdout.readlines() + logging.info("result of command '%s' is '%s'!" % (cmd, result)) + return result + + except ImportError: + err_msg = "Cannot find Paramiko package!" + logging.error(err_msg) + raise ImportError(err_msg) + + except (ssh_exception.BadHostKeyException, + ssh_exception.AuthenticationException, + ssh_exception.SSHException): + + err_msg = 'SSH connection error or command execution failed!' + logging.error(err_msg) + raise Exception(err_msg) + + except Exception as exc: + logging.error( + 'Failed to execute command "%s", exception is %s' % (cmd, exc) + ) + raise Exception(exc) + + finally: + for resource in [stdin, stdout, stderr]: + if resource: + resource.close() + + client.close() |