aboutsummaryrefslogtreecommitdiffstats
path: root/keystone-moon/keystone/common/sql/migration_helpers.py
diff options
context:
space:
mode:
Diffstat (limited to 'keystone-moon/keystone/common/sql/migration_helpers.py')
-rw-r--r--keystone-moon/keystone/common/sql/migration_helpers.py129
1 files changed, 74 insertions, 55 deletions
diff --git a/keystone-moon/keystone/common/sql/migration_helpers.py b/keystone-moon/keystone/common/sql/migration_helpers.py
index aaa59f70..40c1fbb5 100644
--- a/keystone-moon/keystone/common/sql/migration_helpers.py
+++ b/keystone-moon/keystone/common/sql/migration_helpers.py
@@ -21,37 +21,25 @@ import migrate
from migrate import exceptions
from oslo_config import cfg
from oslo_db.sqlalchemy import migration
-from oslo_serialization import jsonutils
from oslo_utils import importutils
import six
import sqlalchemy
from keystone.common import sql
-from keystone.common.sql import migrate_repo
from keystone import contrib
from keystone import exception
from keystone.i18n import _
CONF = cfg.CONF
-DEFAULT_EXTENSIONS = ['endpoint_filter',
- 'endpoint_policy',
- 'federation',
- 'oauth1',
- 'revoke',
- ]
-
-
-def get_default_domain():
- # Return the reference used for the default domain structure during
- # sql migrations.
- return {
- 'id': CONF.identity.default_domain_id,
- 'name': 'Default',
- 'enabled': True,
- 'extra': jsonutils.dumps({'description': 'Owns users and tenants '
- '(i.e. projects) available '
- 'on Identity API v2.'})}
+DEFAULT_EXTENSIONS = []
+
+MIGRATED_EXTENSIONS = ['endpoint_policy',
+ 'federation',
+ 'oauth1',
+ 'revoke',
+ 'endpoint_filter'
+ ]
# Different RDBMSs use different schemes for naming the Foreign Key
@@ -117,9 +105,8 @@ def rename_tables_with_constraints(renames, constraints, engine):
`renames` is a dict, mapping {'to_table_name': from_table, ...}
"""
-
if engine.name != 'sqlite':
- # Sqlite doesn't support constraints, so nothing to remove.
+ # SQLite doesn't support constraints, so nothing to remove.
remove_constraints(constraints)
for to_table_name in renames:
@@ -141,11 +128,34 @@ def find_migrate_repo(package=None, repo_name='migrate_repo'):
def _sync_common_repo(version):
abs_path = find_migrate_repo()
- init_version = migrate_repo.DB_INIT_VERSION
- engine = sql.get_engine()
- _assert_not_schema_downgrade(version=version)
- migration.db_sync(engine, abs_path, version=version,
- init_version=init_version, sanity_check=False)
+ init_version = get_init_version()
+ with sql.session_for_write() as session:
+ engine = session.get_bind()
+ _assert_not_schema_downgrade(version=version)
+ migration.db_sync(engine, abs_path, version=version,
+ init_version=init_version, sanity_check=False)
+
+
+def get_init_version(abs_path=None):
+ """Get the initial version of a migrate repository
+
+ :param abs_path: Absolute path to migrate repository.
+ :return: initial version number or None, if DB is empty.
+ """
+ if abs_path is None:
+ abs_path = find_migrate_repo()
+
+ repo = migrate.versioning.repository.Repository(abs_path)
+
+ # Sadly, Repository has a `latest` but not an `oldest`.
+ # The value is a VerNum object which needs to be converted into an int.
+ oldest = int(min(repo.versions.versions))
+
+ if oldest < 1:
+ return None
+
+ # The initial version is one less
+ return oldest - 1
def _assert_not_schema_downgrade(extension=None, version=None):
@@ -153,40 +163,46 @@ def _assert_not_schema_downgrade(extension=None, version=None):
try:
current_ver = int(six.text_type(get_db_version(extension)))
if int(version) < current_ver:
- raise migration.exception.DbMigrationError()
- except exceptions.DatabaseNotControlledError:
+ raise migration.exception.DbMigrationError(
+ _("Unable to downgrade schema"))
+ except exceptions.DatabaseNotControlledError: # nosec
# NOTE(morganfainberg): The database is not controlled, this action
# cannot be a downgrade.
pass
def _sync_extension_repo(extension, version):
- init_version = 0
- engine = sql.get_engine()
+ if extension in MIGRATED_EXTENSIONS:
+ raise exception.MigrationMovedFailure(extension=extension)
+
+ with sql.session_for_write() as session:
+ engine = session.get_bind()
- try:
- package_name = '.'.join((contrib.__name__, extension))
- package = importutils.import_module(package_name)
- except ImportError:
- raise ImportError(_("%s extension does not exist.")
- % package_name)
- try:
- abs_path = find_migrate_repo(package)
try:
- migration.db_version_control(sql.get_engine(), abs_path)
- # Register the repo with the version control API
- # If it already knows about the repo, it will throw
- # an exception that we can safely ignore
- except exceptions.DatabaseAlreadyControlledError:
- pass
- except exception.MigrationNotProvided as e:
- print(e)
- sys.exit(1)
+ package_name = '.'.join((contrib.__name__, extension))
+ package = importutils.import_module(package_name)
+ except ImportError:
+ raise ImportError(_("%s extension does not exist.")
+ % package_name)
+ try:
+ abs_path = find_migrate_repo(package)
+ try:
+ migration.db_version_control(engine, abs_path)
+ # Register the repo with the version control API
+ # If it already knows about the repo, it will throw
+ # an exception that we can safely ignore
+ except exceptions.DatabaseAlreadyControlledError: # nosec
+ pass
+ except exception.MigrationNotProvided as e:
+ print(e)
+ sys.exit(1)
+
+ _assert_not_schema_downgrade(extension=extension, version=version)
- _assert_not_schema_downgrade(extension=extension, version=version)
+ init_version = get_init_version(abs_path=abs_path)
- migration.db_sync(engine, abs_path, version=version,
- init_version=init_version, sanity_check=False)
+ migration.db_sync(engine, abs_path, version=version,
+ init_version=init_version, sanity_check=False)
def sync_database_to_version(extension=None, version=None):
@@ -203,8 +219,10 @@ def sync_database_to_version(extension=None, version=None):
def get_db_version(extension=None):
if not extension:
- return migration.db_version(sql.get_engine(), find_migrate_repo(),
- migrate_repo.DB_INIT_VERSION)
+ with sql.session_for_write() as session:
+ return migration.db_version(session.get_bind(),
+ find_migrate_repo(),
+ get_init_version())
try:
package_name = '.'.join((contrib.__name__, extension))
@@ -213,8 +231,9 @@ def get_db_version(extension=None):
raise ImportError(_("%s extension does not exist.")
% package_name)
- return migration.db_version(
- sql.get_engine(), find_migrate_repo(package), 0)
+ with sql.session_for_write() as session:
+ return migration.db_version(
+ session.get_bind(), find_migrate_repo(package), 0)
def print_db_version(extension=None):