diff options
Diffstat (limited to 'keystone-moon/keystone/tests/unit/ksfixtures/database.py')
-rw-r--r-- | keystone-moon/keystone/tests/unit/ksfixtures/database.py | 75 |
1 files changed, 64 insertions, 11 deletions
diff --git a/keystone-moon/keystone/tests/unit/ksfixtures/database.py b/keystone-moon/keystone/tests/unit/ksfixtures/database.py index 6f23a99d..52c35cee 100644 --- a/keystone-moon/keystone/tests/unit/ksfixtures/database.py +++ b/keystone-moon/keystone/tests/unit/ksfixtures/database.py @@ -28,12 +28,13 @@ CONF = cfg.CONF def run_once(f): """A decorator to ensure the decorated function is only executed once. - The decorated function cannot expect any arguments. + The decorated function is assumed to have a one parameter. + """ @functools.wraps(f) - def wrapper(): + def wrapper(one): if not wrapper.already_ran: - f() + f(one) wrapper.already_ran = True wrapper.already_ran = False return wrapper @@ -51,7 +52,7 @@ def initialize_sql_session(): @run_once -def _load_sqlalchemy_models(): +def _load_sqlalchemy_models(version_specifiers): """Find all modules containing SQLAlchemy models and import them. This creates more consistent, deterministic test runs because tables @@ -66,6 +67,24 @@ def _load_sqlalchemy_models(): as more models are imported. Importing all models at the start of the test run avoids this problem. + version_specifiers is a dict that contains any specific driver versions + that have been requested. The dict is of the form: + + {<module_name> : {'versioned_backend' : <name of backend requested>, + 'versionless_backend' : <name of default backend>} + } + + For example: + + {'keystone.assignment': {'versioned_backend' : 'V8_backends', + 'versionless_backend' : 'backends'}, + 'keystone.identity': {'versioned_backend' : 'V9_backends', + 'versionless_backend' : 'backends'} + } + + The version_specifiers will be used to load the correct driver. The + algorithm for this assumes that versioned drivers begin in 'V'. + """ keystone_root = os.path.normpath(os.path.join( os.path.dirname(__file__), '..', '..', '..')) @@ -78,25 +97,59 @@ def _load_sqlalchemy_models(): # The root will be prefixed with an instance of os.sep, which will # make the root after replacement '.<root>', the 'keystone' part # of the module path is always added to the front - module_name = ('keystone.%s.sql' % + module_root = ('keystone.%s' % root.replace(os.sep, '.').lstrip('.')) + module_components = module_root.split('.') + module_without_backends = '' + for x in range(0, len(module_components) - 1): + module_without_backends += module_components[x] + '.' + module_without_backends = module_without_backends.rstrip('.') + this_backend = module_components[len(module_components) - 1] + + # At this point module_without_backends might be something like + # 'keystone.assignment', while this_backend might be something + # 'V8_backends'. + + if module_without_backends.startswith('keystone.contrib'): + # All the sql modules have now been moved into the core tree + # so no point in loading these again here (and, in fact, doing + # so might break trying to load a versioned driver. + continue + + if module_without_backends in version_specifiers: + # OK, so there is a request for a specific version of this one. + # We therefore should skip any other versioned backend as well + # as the non-versioned one. + version = version_specifiers[module_without_backends] + if ((this_backend != version['versioned_backend'] and + this_backend.startswith('V')) or + this_backend == version['versionless_backend']): + continue + else: + # No versioned driver requested, so ignore any that are + # versioned + if this_backend.startswith('V'): + continue + + module_name = module_root + '.sql' __import__(module_name) class Database(fixtures.Fixture): - """A fixture for setting up and tearing down a database. - - """ + """A fixture for setting up and tearing down a database.""" - def __init__(self): + def __init__(self, version_specifiers=None): super(Database, self).__init__() initialize_sql_session() - _load_sqlalchemy_models() + if version_specifiers is None: + version_specifiers = {} + _load_sqlalchemy_models(version_specifiers) def setUp(self): super(Database, self).setUp() - self.engine = sql.get_engine() + with sql.session_for_write() as session: + self.engine = session.get_bind() self.addCleanup(sql.cleanup) sql.ModelBase.metadata.create_all(bind=self.engine) self.addCleanup(sql.ModelBase.metadata.drop_all, bind=self.engine) |