diff options
Diffstat (limited to 'app')
-rw-r--r-- | app/discover/fetchers/db/db_access.py | 27 | ||||
-rw-r--r-- | app/test/fetch/test_fetch.py | 2 |
2 files changed, 20 insertions, 9 deletions
diff --git a/app/discover/fetchers/db/db_access.py b/app/discover/fetchers/db/db_access.py index 0174c4b..47f4f9e 100644 --- a/app/discover/fetchers/db/db_access.py +++ b/app/discover/fetchers/db/db_access.py @@ -11,6 +11,7 @@ import mysql.connector from discover.configuration import Configuration from discover.fetcher import Fetcher +from discover.scan_error import ScanError from utils.string_utils import jsonify @@ -27,15 +28,7 @@ class DbAccess(Fetcher): self.config = Configuration() self.conf = self.config.get("mysql") self.connect_to_db() - cursor = DbAccess.conn.cursor(dictionary=True) - try: - # check if DB schema 'neutron' exists - cursor.execute("SELECT COUNT(*) FROM neutron.agents") - for row in cursor: - pass - self.neutron_db = "neutron" - except (AttributeError, mysql.connector.errors.ProgrammingError): - self.neutron_db = "ml2_neutron" + self.neutron_db = self.get_neutron_db_name() def db_connect(self, _host, _port, _user, _pwd, _database): if DbAccess.conn: @@ -54,6 +47,22 @@ class DbAccess(Fetcher): return DbAccess.query_count_per_con = 0 + @staticmethod + def get_neutron_db_name(): + # check if DB schema 'neutron' exists + cursor = DbAccess.conn.cursor(dictionary=True) + cursor.execute('SHOW DATABASES') + matches = [] + for row in cursor: + if 'neutron' in row.get('Database', ''): + matches.append(row) + if not matches: + raise ScanError('Unable to find Neutron schema in OpenStack DB') + if len(matches) > 1: + raise ScanError('Found multiple possible names for Neutron schema ' + 'in OpenStack DB') + return matches[0] + def connect_to_db(self, force=False): if DbAccess.conn: if not force: diff --git a/app/test/fetch/test_fetch.py b/app/test/fetch/test_fetch.py index 55d7d4c..d40a52c 100644 --- a/app/test/fetch/test_fetch.py +++ b/app/test/fetch/test_fetch.py @@ -59,6 +59,8 @@ class TestFetch(unittest.TestCase): self.inv = InventoryMgr() self.inv.set_collections(self.inventory_collection) DbAccess.conn = MagicMock() + DbAccess.get_neutron_db_name = MagicMock() + DbAccess.get_neutron_db_name.return_value = "neutron" SshConnection.connect = MagicMock() SshConnection.check_definitions = MagicMock() SshConn.check_definitions = MagicMock() |