aboutsummaryrefslogtreecommitdiffstats
path: root/keystone-moon/keystone/tests/unit/ksfixtures/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'keystone-moon/keystone/tests/unit/ksfixtures/database.py')
-rw-r--r--keystone-moon/keystone/tests/unit/ksfixtures/database.py75
1 files changed, 64 insertions, 11 deletions
diff --git a/keystone-moon/keystone/tests/unit/ksfixtures/database.py b/keystone-moon/keystone/tests/unit/ksfixtures/database.py
index 6f23a99d..52c35cee 100644
--- a/keystone-moon/keystone/tests/unit/ksfixtures/database.py
+++ b/keystone-moon/keystone/tests/unit/ksfixtures/database.py
@@ -28,12 +28,13 @@ CONF = cfg.CONF
def run_once(f):
"""A decorator to ensure the decorated function is only executed once.
- The decorated function cannot expect any arguments.
+ The decorated function is assumed to have a one parameter.
+
"""
@functools.wraps(f)
- def wrapper():
+ def wrapper(one):
if not wrapper.already_ran:
- f()
+ f(one)
wrapper.already_ran = True
wrapper.already_ran = False
return wrapper
@@ -51,7 +52,7 @@ def initialize_sql_session():
@run_once
-def _load_sqlalchemy_models():
+def _load_sqlalchemy_models(version_specifiers):
"""Find all modules containing SQLAlchemy models and import them.
This creates more consistent, deterministic test runs because tables
@@ -66,6 +67,24 @@ def _load_sqlalchemy_models():
as more models are imported. Importing all models at the start of
the test run avoids this problem.
+ version_specifiers is a dict that contains any specific driver versions
+ that have been requested. The dict is of the form:
+
+ {<module_name> : {'versioned_backend' : <name of backend requested>,
+ 'versionless_backend' : <name of default backend>}
+ }
+
+ For example:
+
+ {'keystone.assignment': {'versioned_backend' : 'V8_backends',
+ 'versionless_backend' : 'backends'},
+ 'keystone.identity': {'versioned_backend' : 'V9_backends',
+ 'versionless_backend' : 'backends'}
+ }
+
+ The version_specifiers will be used to load the correct driver. The
+ algorithm for this assumes that versioned drivers begin in 'V'.
+
"""
keystone_root = os.path.normpath(os.path.join(
os.path.dirname(__file__), '..', '..', '..'))
@@ -78,25 +97,59 @@ def _load_sqlalchemy_models():
# The root will be prefixed with an instance of os.sep, which will
# make the root after replacement '.<root>', the 'keystone' part
# of the module path is always added to the front
- module_name = ('keystone.%s.sql' %
+ module_root = ('keystone.%s' %
root.replace(os.sep, '.').lstrip('.'))
+ module_components = module_root.split('.')
+ module_without_backends = ''
+ for x in range(0, len(module_components) - 1):
+ module_without_backends += module_components[x] + '.'
+ module_without_backends = module_without_backends.rstrip('.')
+ this_backend = module_components[len(module_components) - 1]
+
+ # At this point module_without_backends might be something like
+ # 'keystone.assignment', while this_backend might be something
+ # 'V8_backends'.
+
+ if module_without_backends.startswith('keystone.contrib'):
+ # All the sql modules have now been moved into the core tree
+ # so no point in loading these again here (and, in fact, doing
+ # so might break trying to load a versioned driver.
+ continue
+
+ if module_without_backends in version_specifiers:
+ # OK, so there is a request for a specific version of this one.
+ # We therefore should skip any other versioned backend as well
+ # as the non-versioned one.
+ version = version_specifiers[module_without_backends]
+ if ((this_backend != version['versioned_backend'] and
+ this_backend.startswith('V')) or
+ this_backend == version['versionless_backend']):
+ continue
+ else:
+ # No versioned driver requested, so ignore any that are
+ # versioned
+ if this_backend.startswith('V'):
+ continue
+
+ module_name = module_root + '.sql'
__import__(module_name)
class Database(fixtures.Fixture):
- """A fixture for setting up and tearing down a database.
-
- """
+ """A fixture for setting up and tearing down a database."""
- def __init__(self):
+ def __init__(self, version_specifiers=None):
super(Database, self).__init__()
initialize_sql_session()
- _load_sqlalchemy_models()
+ if version_specifiers is None:
+ version_specifiers = {}
+ _load_sqlalchemy_models(version_specifiers)
def setUp(self):
super(Database, self).setUp()
- self.engine = sql.get_engine()
+ with sql.session_for_write() as session:
+ self.engine = session.get_bind()
self.addCleanup(sql.cleanup)
sql.ModelBase.metadata.create_all(bind=self.engine)
self.addCleanup(sql.ModelBase.metadata.drop_all, bind=self.engine)