From ca276f452540f68c08cb3df9049e9e7876364dac Mon Sep 17 00:00:00 2001 From: spisarski Date: Thu, 27 Jul 2017 10:27:14 -0600 Subject: Ensure library and tests close all necessary resources. The SNAPS-OO library and tests had left open files, ssh, and scp connections. These have all now been wrapped with try/finally blocks. JIRA: SNAPS-152 Change-Id: I43e09978b5c075bd78ff3279c0799556b8758878 Signed-off-by: spisarski --- snaps/file_utils.py | 75 ++++++++++++++-------- snaps/openstack/create_instance.py | 1 + snaps/openstack/tests/conf/os_credentials_tests.py | 24 +++---- snaps/openstack/tests/create_instance_tests.py | 5 +- snaps/openstack/tests/create_keypairs_tests.py | 34 ++++++++-- snaps/openstack/utils/glance_utils.py | 38 ++++++----- snaps/openstack/utils/nova_utils.py | 49 +++++++++----- snaps/playbook_runner.py | 37 +++++++---- snaps/provisioning/tests/ansible_utils_tests.py | 52 +++++++++++---- snaps/tests/file_utils_tests.py | 6 +- 10 files changed, 218 insertions(+), 103 deletions(-) diff --git a/snaps/file_utils.py b/snaps/file_utils.py index a7ed13c..ff2f1b3 100644 --- a/snaps/file_utils.py +++ b/snaps/file_utils.py @@ -32,7 +32,8 @@ logger = logging.getLogger('file_utils') def file_exists(file_path): """ - Returns True if the image file already exists and throws an exception if the path is a directory + Returns True if the image file already exists and throws an exception if + the path is a directory :return: """ if os.path.exists(file_path): @@ -55,7 +56,7 @@ def download(url, dest_path, name=None): dest = dest_path + '/' + name logger.debug('Downloading file from - ' + url) # Override proxy settings to use localhost to download file - f = None + download_file = None if not os.path.isdir(dest_path): try: @@ -63,14 +64,14 @@ def download(url, dest_path, name=None): except: raise try: - with open(dest, 'wb') as f: - logger.debug('Saving file to - ' + os.path.abspath(f.name)) + with open(dest, 'wb') as download_file: + logger.debug('Saving file to - ' + os.path.abspath(download_file.name)) response = __get_url_response(url) - f.write(response.read()) - return f + download_file.write(response.read()) + return download_file finally: - if f: - f.close() + if download_file: + download_file.close() def get_content_length(url): @@ -102,32 +103,45 @@ def read_yaml(config_file_path): :return: a dictionary """ logger.debug('Attempting to load configuration file - ' + config_file_path) - with open(config_file_path) as config_file: - config = yaml.safe_load(config_file) - logger.info('Loaded configuration') - config_file.close() - logger.info('Closing configuration file') - return config + config_file = None + try: + with open(config_file_path) as config_file: + config = yaml.safe_load(config_file) + logger.info('Loaded configuration') + return config + finally: + if config_file: + logger.info('Closing configuration file') + config_file.close() def read_os_env_file(os_env_filename): """ Reads the OS environment source file and returns a map of each key/value - Will ignore lines beginning with a '#' and will replace any single or double quotes contained within the value + Will ignore lines beginning with a '#' and will replace any single or + double quotes contained within the value :param os_env_filename: The name of the OS environment file to read :return: a dictionary """ if os_env_filename: - logger.info('Attempting to read OS environment file - ' + os_env_filename) + logger.info('Attempting to read OS environment file - %s', + os_env_filename) out = {} - for line in open(os_env_filename): - line = line.lstrip() - if not line.startswith('#') and line.startswith('export '): - line = line.lstrip('export ').strip() - tokens = line.split('=') - if len(tokens) > 1: - # Remove leading and trailing ' & " characters from value - out[tokens[0]] = tokens[1].lstrip('\'').lstrip('\"').rstrip('\'').rstrip('\"') + env_file = None + try: + env_file = open(os_env_filename) + for line in env_file: + line = line.lstrip() + if not line.startswith('#') and line.startswith('export '): + line = line.lstrip('export ').strip() + tokens = line.split('=') + if len(tokens) > 1: + # Remove leading and trailing ' & " characters from + # value + out[tokens[0]] = tokens[1].lstrip('\'').lstrip('\"').rstrip('\'').rstrip('\"') + finally: + if env_file: + env_file.close() return out @@ -138,7 +152,12 @@ def read_file(filename): :return: """ out = str() - for line in open(filename): - out += line - - return out + the_file = None + try: + the_file = open(filename) + for line in the_file: + out += line + return out + finally: + if the_file: + the_file.close() diff --git a/snaps/openstack/create_instance.py b/snaps/openstack/create_instance.py index d5917a8..997b5a5 100644 --- a/snaps/openstack/create_instance.py +++ b/snaps/openstack/create_instance.py @@ -618,6 +618,7 @@ class OpenStackVmInstance: if len(self.__floating_ips) > 0: ssh = self.ssh_client() if ssh: + ssh.close() return True return False diff --git a/snaps/openstack/tests/conf/os_credentials_tests.py b/snaps/openstack/tests/conf/os_credentials_tests.py index e7c34b9..4a2ce3d 100644 --- a/snaps/openstack/tests/conf/os_credentials_tests.py +++ b/snaps/openstack/tests/conf/os_credentials_tests.py @@ -56,17 +56,17 @@ class ProxySettingsUnitTests(unittest.TestCase): def test_minimum(self): proxy_settings = ProxySettings(host='foo', port=1234) self.assertEqual('foo', proxy_settings.host) - self.assertEqual(1234, proxy_settings.port) + self.assertEqual('1234', proxy_settings.port) self.assertEqual('foo', proxy_settings.https_host) - self.assertEqual(1234, proxy_settings.https_port) + self.assertEqual('1234', proxy_settings.https_port) self.assertIsNone(proxy_settings.ssh_proxy_cmd) def test_minimum_kwargs(self): proxy_settings = ProxySettings(**{'host': 'foo', 'port': 1234}) self.assertEqual('foo', proxy_settings.host) - self.assertEqual(1234, proxy_settings.port) + self.assertEqual('1234', proxy_settings.port) self.assertEqual('foo', proxy_settings.https_host) - self.assertEqual(1234, proxy_settings.https_port) + self.assertEqual('1234', proxy_settings.https_port) self.assertIsNone(proxy_settings.ssh_proxy_cmd) def test_all(self): @@ -74,9 +74,9 @@ class ProxySettingsUnitTests(unittest.TestCase): host='foo', port=1234, https_host='bar', https_port=2345, ssh_proxy_cmd='proxy command') self.assertEqual('foo', proxy_settings.host) - self.assertEqual(1234, proxy_settings.port) + self.assertEqual('1234', proxy_settings.port) self.assertEqual('bar', proxy_settings.https_host) - self.assertEqual(2345, proxy_settings.https_port) + self.assertEqual('2345', proxy_settings.https_port) self.assertEqual('proxy command', proxy_settings.ssh_proxy_cmd) def test_all_kwargs(self): @@ -84,9 +84,9 @@ class ProxySettingsUnitTests(unittest.TestCase): **{'host': 'foo', 'port': 1234, 'https_host': 'bar', 'https_port': 2345, 'ssh_proxy_cmd': 'proxy command'}) self.assertEqual('foo', proxy_settings.host) - self.assertEqual(1234, proxy_settings.port) + self.assertEqual('1234', proxy_settings.port) self.assertEqual('bar', proxy_settings.https_host) - self.assertEqual(2345, proxy_settings.https_port) + self.assertEqual('2345', proxy_settings.https_port) self.assertEqual('proxy command', proxy_settings.ssh_proxy_cmd) @@ -245,7 +245,7 @@ class OSCredsUnitTests(unittest.TestCase): self.assertEqual('admin', os_creds.interface) self.assertFalse(os_creds.cacert) self.assertEqual('foo', os_creds.proxy_settings.host) - self.assertEqual(1234, os_creds.proxy_settings.port) + self.assertEqual('1234', os_creds.proxy_settings.port) self.assertIsNone(os_creds.proxy_settings.ssh_proxy_cmd) self.assertIsNone(os_creds.region_name) @@ -269,7 +269,7 @@ class OSCredsUnitTests(unittest.TestCase): self.assertEqual('admin', os_creds.interface) self.assertFalse(os_creds.cacert) self.assertEqual('foo', os_creds.proxy_settings.host) - self.assertEqual(1234, os_creds.proxy_settings.port) + self.assertEqual('1234', os_creds.proxy_settings.port) self.assertIsNone(os_creds.proxy_settings.ssh_proxy_cmd) self.assertEqual('test_region', os_creds.region_name) @@ -290,7 +290,7 @@ class OSCredsUnitTests(unittest.TestCase): self.assertEqual('admin', os_creds.interface) self.assertFalse(os_creds.cacert) self.assertEqual('foo', os_creds.proxy_settings.host) - self.assertEqual(1234, os_creds.proxy_settings.port) + self.assertEqual('1234', os_creds.proxy_settings.port) self.assertIsNone(os_creds.proxy_settings.ssh_proxy_cmd) def test_proxy_settings_dict_kwargs(self): @@ -312,6 +312,6 @@ class OSCredsUnitTests(unittest.TestCase): self.assertEqual('admin', os_creds.interface) self.assertFalse(os_creds.cacert) self.assertEqual('foo', os_creds.proxy_settings.host) - self.assertEqual(1234, os_creds.proxy_settings.port) + self.assertEqual('1234', os_creds.proxy_settings.port) self.assertIsNone(os_creds.proxy_settings.ssh_proxy_cmd) self.assertEqual('test_region', os_creds.region_name) diff --git a/snaps/openstack/tests/create_instance_tests.py b/snaps/openstack/tests/create_instance_tests.py index 75b0ed3..1922146 100644 --- a/snaps/openstack/tests/create_instance_tests.py +++ b/snaps/openstack/tests/create_instance_tests.py @@ -1717,7 +1717,10 @@ def validate_ssh_client(instance_creator): if ssh_active: ssh_client = instance_creator.ssh_client() if ssh_client: - out = ssh_client.exec_command('pwd')[1] + try: + out = ssh_client.exec_command('pwd')[1] + finally: + ssh_client.close() else: return False diff --git a/snaps/openstack/tests/create_keypairs_tests.py b/snaps/openstack/tests/create_keypairs_tests.py index 0b35095..7b75d05 100644 --- a/snaps/openstack/tests/create_keypairs_tests.py +++ b/snaps/openstack/tests/create_keypairs_tests.py @@ -285,9 +285,15 @@ class CreateKeypairsTests(OSIntegrationTestCase): self.keypair_creator.get_keypair()) self.assertEqual(self.keypair_creator.get_keypair(), keypair) - file_key = open(os.path.expanduser(self.pub_file_path)).read() - self.assertEqual(self.keypair_creator.get_keypair().public_key, - file_key) + pub_file = None + try: + pub_file = open(os.path.expanduser(self.pub_file_path)) + file_key = pub_file.read() + self.assertEqual(self.keypair_creator.get_keypair().public_key, + file_key) + finally: + if pub_file: + pub_file.close() def test_create_keypair_save_both(self): """ @@ -305,7 +311,16 @@ class CreateKeypairsTests(OSIntegrationTestCase): self.keypair_creator.get_keypair()) self.assertEqual(self.keypair_creator.get_keypair(), keypair) - file_key = open(os.path.expanduser(self.pub_file_path)).read() + pub_file = None + try: + pub_file = open(os.path.expanduser(self.pub_file_path)) + file_key = pub_file.read() + self.assertEqual(self.keypair_creator.get_keypair().public_key, + file_key) + finally: + if pub_file: + pub_file.close() + self.assertEqual(self.keypair_creator.get_keypair().public_key, file_key) @@ -328,7 +343,16 @@ class CreateKeypairsTests(OSIntegrationTestCase): self.keypair_creator.get_keypair()) self.assertEqual(self.keypair_creator.get_keypair(), keypair) - file_key = open(os.path.expanduser(self.pub_file_path)).read() + pub_file = None + try: + pub_file = open(os.path.expanduser(self.pub_file_path)) + file_key = pub_file.read() + self.assertEqual(self.keypair_creator.get_keypair().public_key, + file_key) + finally: + if pub_file: + pub_file.close() + self.assertEqual(self.keypair_creator.get_keypair().public_key, file_key) diff --git a/snaps/openstack/utils/glance_utils.py b/snaps/openstack/utils/glance_utils.py index 49bfe95..ad9c5e5 100644 --- a/snaps/openstack/utils/glance_utils.py +++ b/snaps/openstack/utils/glance_utils.py @@ -124,22 +124,30 @@ def __create_image_v1(glance, image_settings): 'name': image_settings.name, 'disk_format': image_settings.format, 'container_format': 'bare', 'is_public': image_settings.public} - if image_settings.extra_properties: - kwargs['properties'] = image_settings.extra_properties - - if image_settings.url: - kwargs['location'] = image_settings.url - elif image_settings.image_file: - image_file = open(image_settings.image_file, 'rb') - kwargs['data'] = image_file - else: - logger.warn('Unable to create image with name - %s. No file or URL', - image_settings.name) - return None + image_file = None - created_image = glance.images.create(**kwargs) - return Image(name=image_settings.name, image_id=created_image.id, - size=created_image.size, properties=created_image.properties) + try: + if image_settings.extra_properties: + kwargs['properties'] = image_settings.extra_properties + + if image_settings.url: + kwargs['location'] = image_settings.url + elif image_settings.image_file: + image_file = open(image_settings.image_file, 'rb') + kwargs['data'] = image_file + else: + logger.warn( + 'Unable to create image with name - %s. No file or URL', + image_settings.name) + return None + + created_image = glance.images.create(**kwargs) + return Image(name=image_settings.name, image_id=created_image.id, + size=created_image.size, + properties=created_image.properties) + finally: + if image_file: + image_file.close() def __create_image_v2(glance, image_settings): diff --git a/snaps/openstack/utils/nova_utils.py b/snaps/openstack/utils/nova_utils.py index ab434f1..b148bc5 100644 --- a/snaps/openstack/utils/nova_utils.py +++ b/snaps/openstack/utils/nova_utils.py @@ -232,12 +232,18 @@ def save_keys_to_files(keys=None, pub_file_path=None, priv_file_path=None): if not os.path.isdir(pub_dir): os.mkdir(pub_dir) - public_handle = open(pub_expand_file, 'wb') - public_bytes = keys.public_key().public_bytes( - serialization.Encoding.OpenSSH, - serialization.PublicFormat.OpenSSH) - public_handle.write(public_bytes) - public_handle.close() + + public_handle = None + try: + public_handle = open(pub_expand_file, 'wb') + public_bytes = keys.public_key().public_bytes( + serialization.Encoding.OpenSSH, + serialization.PublicFormat.OpenSSH) + public_handle.write(public_bytes) + finally: + if public_handle: + public_handle.close() + os.chmod(pub_expand_file, 0o400) logger.info("Saved public key to - " + pub_expand_file) if priv_file_path: @@ -246,13 +252,19 @@ def save_keys_to_files(keys=None, pub_file_path=None, priv_file_path=None): priv_dir = os.path.dirname(priv_expand_file) if not os.path.isdir(priv_dir): os.mkdir(priv_dir) - private_handle = open(priv_expand_file, 'wb') - private_handle.write( - keys.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption())) - private_handle.close() + + private_handle = None + try: + private_handle = open(priv_expand_file, 'wb') + private_handle.write( + keys.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption())) + finally: + if private_handle: + private_handle.close() + os.chmod(priv_expand_file, 0o400) logger.info("Saved private key to - " + priv_expand_file) @@ -265,9 +277,14 @@ def upload_keypair_file(nova, name, file_path): :param file_path: the path to the public key file :return: the keypair object """ - with open(os.path.expanduser(file_path), 'rb') as fpubkey: - logger.info('Saving keypair to - ' + file_path) - return upload_keypair(nova, name, fpubkey.read()) + fpubkey = None + try: + with open(os.path.expanduser(file_path), 'rb') as fpubkey: + logger.info('Saving keypair to - ' + file_path) + return upload_keypair(nova, name, fpubkey.read()) + finally: + if fpubkey: + fpubkey.close() def upload_keypair(nova, name, key): diff --git a/snaps/playbook_runner.py b/snaps/playbook_runner.py index 3710309..4dba550 100644 --- a/snaps/playbook_runner.py +++ b/snaps/playbook_runner.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016 Cable Television Laboratories, Inc. ("CableLabs") +# Copyright (c) 2017 Cable Television Laboratories, Inc. ("CableLabs") # and others. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,7 +27,8 @@ logger = logging.getLogger('playbook_runner') def main(parsed_args): """ - Uses ansible_utils for applying Ansible Playbooks to machines with a private key + Uses ansible_utils for applying Ansible Playbooks to machines with a + private key """ logging.basicConfig(level=logging.DEBUG) logger.info('Starting Playbook Runner') @@ -35,24 +36,36 @@ def main(parsed_args): proxy_settings = None if parsed_args.http_proxy: tokens = re.split(':', parsed_args.http_proxy) - proxy_settings = ProxySettings(tokens[0], tokens[1], parsed_args.ssh_proxy_cmd) + proxy_settings = ProxySettings(host=tokens[0], port=tokens[1], + ssh_proxy_cmd=parsed_args.ssh_proxy_cmd) # Ensure can get an SSH client - ansible_utils.ssh_client(parsed_args.ip_addr, parsed_args.host_user, parsed_args.priv_key, proxy_settings) + ssh = ansible_utils.ssh_client(parsed_args.ip_addr, parsed_args.host_user, + parsed_args.priv_key, proxy_settings) + if ssh: + ssh.close() - retval = ansible_utils.apply_playbook(parsed_args.playbook, [parsed_args.ip_addr], parsed_args.host_user, - parsed_args.priv_key, variables={'name': 'Foo'}, proxy_setting=proxy_settings) + retval = ansible_utils.apply_playbook( + parsed_args.playbook, [parsed_args.ip_addr], parsed_args.host_user, + parsed_args.priv_key, variables={'name': 'Foo'}, + proxy_setting=proxy_settings) exit(retval) if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-a', '--ip-addr', dest='ip_addr', required=True, help='The Host IP Address') - parser.add_argument('-k', '--priv-key', dest='priv_key', required=True, help='The location of the private key file') - parser.add_argument('-u', '--host-user', dest='host_user', required=True, help='Host user account') - parser.add_argument('-b', '--playbook', dest='playbook', required=True, help='Playbook Location') - parser.add_argument('-p', '--http-proxy', dest='http_proxy', required=False, help=':') - parser.add_argument('-s', '--ssh-proxy-cmd', dest='ssh_proxy_cmd', required=False) + parser.add_argument('-a', '--ip-addr', dest='ip_addr', required=True, + help='The Host IP Address') + parser.add_argument('-k', '--priv-key', dest='priv_key', required=True, + help='The location of the private key file') + parser.add_argument('-u', '--host-user', dest='host_user', required=True, + help='Host user account') + parser.add_argument('-b', '--playbook', dest='playbook', required=True, + help='Playbook Location') + parser.add_argument('-p', '--http-proxy', dest='http_proxy', + required=False, help=':') + parser.add_argument('-s', '--ssh-proxy-cmd', dest='ssh_proxy_cmd', + required=False) args = parser.parse_args() main(args) diff --git a/snaps/provisioning/tests/ansible_utils_tests.py b/snaps/provisioning/tests/ansible_utils_tests.py index 203ba33..da056b2 100644 --- a/snaps/provisioning/tests/ansible_utils_tests.py +++ b/snaps/provisioning/tests/ansible_utils_tests.py @@ -239,9 +239,14 @@ class AnsibleProvisioningTests(OSIntegrationTestCase): ssh_client = self.inst_creator.ssh_client() self.assertIsNotNone(ssh_client) - out = ssh_client.exec_command('pwd')[1].channel.in_buffer.read(1024) - self.assertIsNotNone(out) - self.assertGreater(len(out), 1) + + try: + out = ssh_client.exec_command('pwd')[1].channel.in_buffer.read( + 1024) + self.assertIsNotNone(out) + self.assertGreater(len(out), 1) + finally: + ssh_client.close() # Need to use the first floating IP as subsequent ones are currently # broken with Apex CO @@ -257,14 +262,25 @@ class AnsibleProvisioningTests(OSIntegrationTestCase): ssh = ansible_utils.ssh_client(ip, user, priv_key, self.os_creds.proxy_settings) self.assertIsNotNone(ssh) - scp = SCPClient(ssh.get_transport()) - scp.get('~/hello.txt', self.test_file_local_path) + + try: + scp = SCPClient(ssh.get_transport()) + scp.get('~/hello.txt', self.test_file_local_path) + finally: + scp.close() + ssh.close() self.assertTrue(os.path.isfile(self.test_file_local_path)) - with open(self.test_file_local_path) as f: - file_contents = f.readline() - self.assertEqual('Hello World!', file_contents) + test_file = None + + try: + with open(self.test_file_local_path) as test_file: + file_contents = test_file.readline() + self.assertEqual('Hello World!', file_contents) + finally: + if test_file: + test_file.close() def test_apply_template_playbook(self): """ @@ -310,11 +326,21 @@ class AnsibleProvisioningTests(OSIntegrationTestCase): ssh = ansible_utils.ssh_client(ip, user, priv_key, self.os_creds.proxy_settings) self.assertIsNotNone(ssh) - scp = SCPClient(ssh.get_transport()) - scp.get('/tmp/hello.txt', self.test_file_local_path) + + try: + scp = SCPClient(ssh.get_transport()) + scp.get('/tmp/hello.txt', self.test_file_local_path) + finally: + scp.close() + ssh.close() self.assertTrue(os.path.isfile(self.test_file_local_path)) - with open(self.test_file_local_path) as f: - file_contents = f.readline() - self.assertEqual('Hello Foo!', file_contents) + test_file = None + try: + with open(self.test_file_local_path) as test_file: + file_contents = test_file.readline() + self.assertEqual('Hello Foo!', file_contents) + finally: + if test_file: + test_file.close() diff --git a/snaps/tests/file_utils_tests.py b/snaps/tests/file_utils_tests.py index f3a622a..ef8b4ae 100644 --- a/snaps/tests/file_utils_tests.py +++ b/snaps/tests/file_utils_tests.py @@ -37,10 +37,14 @@ class FileUtilsTests(unittest.TestCase): os.makedirs(self.test_dir) self.tmpFile = self.test_dir + '/bar.txt' + self.tmp_file_opened = None if not os.path.exists(self.tmpFile): - open(self.tmpFile, 'wb') + self.tmp_file_opened = open(self.tmpFile, 'wb') def tearDown(self): + if self.tmp_file_opened: + self.tmp_file_opened.close() + if os.path.exists(self.test_dir) and os.path.isdir(self.test_dir): shutil.rmtree(self.tmp_dir) -- cgit 1.2.3-korg