diff options
author | bobzhou <zhoubo8@lenovo.com> | 2018-12-17 17:41:39 +0800 |
---|---|---|
committer | bobzhou <zhoubo8@lenovo.com> | 2018-12-17 17:41:39 +0800 |
commit | 31560e9a26e6a7e2e65d18924c938ee7a7683fb3 (patch) | |
tree | 26ec12331a2ec612a0f51c7866fbe41262ca0fe1 /cyborg_enhancement/mitaka_version/cyborg/cyborg/db/sqlalchemy/api.py | |
parent | fc1b7908a77b5cce763f92b3d314057345f3ab7a (diff) |
upload cyborg base mitaka version
Change-Id: Iff52024026c9eff899246be70cc89fe6e65befff
Signed-off-by: bobzhou <zhoubo8@lenovo.com>
Diffstat (limited to 'cyborg_enhancement/mitaka_version/cyborg/cyborg/db/sqlalchemy/api.py')
-rw-r--r-- | cyborg_enhancement/mitaka_version/cyborg/cyborg/db/sqlalchemy/api.py | 513 |
1 files changed, 513 insertions, 0 deletions
diff --git a/cyborg_enhancement/mitaka_version/cyborg/cyborg/db/sqlalchemy/api.py b/cyborg_enhancement/mitaka_version/cyborg/cyborg/db/sqlalchemy/api.py new file mode 100644 index 0000000..22233fb --- /dev/null +++ b/cyborg_enhancement/mitaka_version/cyborg/cyborg/db/sqlalchemy/api.py @@ -0,0 +1,513 @@ +# Copyright 2017 Huawei Technologies Co.,LTD. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""SQLAlchemy storage backend.""" + +import threading +import copy +from oslo_db import api as oslo_db_api +from oslo_db import exception as db_exc +from oslo_db.sqlalchemy import enginefacade +from oslo_db.sqlalchemy import utils as sqlalchemyutils +from oslo_log import log +from oslo_utils import strutils +from oslo_utils import uuidutils +from sqlalchemy.orm.exc import NoResultFound + +from cyborg.common import exception +from cyborg.db.sqlalchemy import models +from cyborg.common.i18n import _ +from cyborg.db import api + + +_CONTEXT = threading.local() +LOG = log.getLogger(__name__) + + +def get_backend(): + """The backend is this module itself.""" + return Connection() + + +def _session_for_read(): + return enginefacade.reader.using(_CONTEXT) + + +def _session_for_write(): + return enginefacade.writer.using(_CONTEXT) + + +def model_query(context, model, *args, **kwargs): + """Query helper for simpler session usage. + + :param context: Context of the query + :param model: Model to query. Must be a subclass of ModelBase. + :param args: Arguments to query. If None - model is used. + + Keyword arguments: + + :keyword project_only: + If set to True, then will do query filter with context's project_id. + if set to False or absent, then will not do query filter with context's + project_id. + :type project_only: bool + """ + + if kwargs.pop("project_only", False): + kwargs["project_id"] = context.tenant + + with _session_for_read() as session: + query = sqlalchemyutils.model_query( + model, session, args, **kwargs) + return query + + +def add_identity_filter(query, value): + """Adds an identity filter to a query. + + Filters results by ID, if supplied value is a valid integer. + Otherwise attempts to filter results by UUID. + + :param query: Initial query to add filter to. + :param value: Value for filtering results by. + :return: Modified query. + """ + if strutils.is_int_like(value): + return query.filter_by(id=value) + elif uuidutils.is_uuid_like(value): + return query.filter_by(uuid=value) + else: + raise exception.InvalidIdentity(identity=value) + + +def _paginate_query(context, model, limit, marker, sort_key, sort_dir, query): + sort_keys = ['id'] + if sort_key and sort_key not in sort_keys: + sort_keys.insert(0, sort_key) + try: + query = sqlalchemyutils.paginate_query(query, model, limit, sort_keys, + marker=marker, + sort_dir=sort_dir) + except db_exc.InvalidSortKey: + raise exception.InvalidParameterValue( + _('The sort_key value "%(key)s" is an invalid field for sorting') + % {'key': sort_key}) + return query.all() + + +class Connection(api.Connection): + """SqlAlchemy connection.""" + + def __init__(self): + pass + + def accelerator_create(self, context, values): + if not values.get('uuid'): + values['uuid'] = uuidutils.generate_uuid() + + accelerator = models.Accelerator() + accelerator.update(values) + + with _session_for_write() as session: + try: + session.add(accelerator) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.AcceleratorAlreadyExists(uuid=values['uuid']) + return accelerator + + def accelerator_get(self, context, uuid): + query = model_query(context, models.Accelerator).filter_by(uuid=uuid) + try: + return query.one() + except NoResultFound: + raise exception.AcceleratorNotFound(uuid=uuid) + + def accelerator_list(self, context, limit, marker, sort_key, sort_dir, + project_only): + query = model_query(context, models.Accelerator, + project_only = project_only) + + return _paginate_query(context, models.Accelerator,limit,marker, + sort_key, sort_dir, query) + + def accelerator_update(self, context, uuid, values): + if 'uuid' in values: + msg = _("Cannot overwrite UUID for existing Accelerator.") + raise exception.InvalidParameterValue(err = msg) + + try: + return self._do_update_accelerator(context, uuid, values) + except db_exc.DBDuplicateEntry as e: + if 'name' in e.columns: + raise exception.DuplicateName(name=values['name']) + + + @oslo_db_api.retry_on_deadlock + def _do_update_accelerator(self, context, uuid, values): + with _session_for_write(): + query = model_query(context, models.Port) + query = add_identity_filter(query, uuid) + try: + ref = query.with_lockmode('update').one() + except NoResultFound: + raise exception.PortNotFound(uuid=uuid) + + ref.update(values) + return ref + + @oslo_db_api.retry_on_deadlock + def accelerator_destory(self, context, uuid): + with _session_for_write(): + query = model_query(context, models.Accelerator) + query = add_identity_filter(query, uuid) + count = query.delete() + if count != 1: + raise exception.AcceleratorNotFound(uuid=uuid) + + + + def port_create(self, context, values): + if not values.get('uuid'): + values['uuid'] = uuidutils.generate_uuid() + if not values.get('is_used'): + values['is_used'] = 0 + + port = models.Port() + port.update(values) + + with _session_for_write() as session: + try: + session.add(port) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.PortAlreadyExists(uuid=values['uuid']) + return port + + def port_get(self, context, uuid): + query = model_query(context, models.Port).filter_by(uuid=uuid) + try: + return query.one() + except NoResultFound: + raise exception.PortNotFound(uuid=uuid) + + def port_get(self, context, computer_node, phy_port_name, pci_slot): + query = model_query(context, models.Port).filter_by(computer_node=computer_node).\ + filter_by(phy_port_name=phy_port_name).filter_by(pci_slot=pci_slot) + try: + return query.one() + except NoResultFound: + return None + + def port_list(self, context, limit, marker, sort_key, sort_dir): + query = model_query(context, models.Port) + + return _paginate_query(context, models.Port, limit, marker, + sort_key, sort_dir, query) + + def port_update(self, context, uuid, values): + if 'uuid' in values: + msg = _("Cannot overwrite UUID for existing Port.") + raise exception.InvalidParameterValue(err=msg) + + try: + return self._do_update_port(context, uuid, values) + except db_exc.DBDuplicateEntry as e: + if 'name' in e.columns: + raise exception.PortDuplicateName(name=values['name']) + + @oslo_db_api.retry_on_deadlock + def _do_update_port(self, context, uuid, values): + with _session_for_write(): + query = model_query(context, models.Port) + query = add_identity_filter(query, uuid) + try: + ref = query.with_lockmode('update').one() + except NoResultFound: + raise exception.PortNotFound(uuid=uuid) + ref.update(values) + return ref + + @oslo_db_api.retry_on_deadlock + def port_destory(self, context, uuid): + with _session_for_write(): + query = model_query(context, models.Port) + query = add_identity_filter(query, uuid) + count = query.delete() + if count == 0: + raise exception.PortNotFound(uuid=uuid) + + + #deployables table operations. + def deployable_create(self, context, values): + if not values.get('uuid'): + values['uuid'] = uuidutils.generate_uuid() + + if values.get('id'): + values.pop('id', None) + + deployable = models.Deployable() + deployable.update(values) + + with _session_for_write() as session: + try: + session.add(deployable) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.DeployableAlreadyExists(uuid=values['uuid']) + return deployable + + def deployable_get(self, context, uuid): + query = model_query( + context, + models.Deployable).filter_by(uuid=uuid) + try: + return query.one() + except NoResultFound: + raise exception.DeployableNotFound(uuid=uuid) + + def deployable_get_by_host(self, context, host): + query = model_query( + context, + models.Deployable).filter_by(host=host) + return query.all() + + def deployable_list(self, context): + query = model_query(context, models.Deployable) + return query.all() + + def deployable_update(self, context, uuid, values): + if 'uuid' in values: + msg = _("Cannot overwrite UUID for an existing Deployable.") + raise exception.InvalidParameterValue(err=msg) + + try: + return self._do_update_deployable(context, uuid, values) + except db_exc.DBDuplicateEntry as e: + if 'name' in e.columns: + raise exception.DuplicateDeployableName(name=values['name']) + + @oslo_db_api.retry_on_deadlock + def _do_update_deployable(self, context, uuid, values): + with _session_for_write(): + query = model_query(context, models.Deployable) + #query = add_identity_filter(query, uuid) + query = query.filter_by(uuid=uuid) + try: + ref = query.with_lockmode('update').one() + except NoResultFound: + raise exception.DeployableNotFound(uuid=uuid) + + ref.update(values) + return ref + + @oslo_db_api.retry_on_deadlock + def deployable_delete(self, context, uuid): + with _session_for_write(): + query = model_query(context, models.Deployable) + query = add_identity_filter(query, uuid) + query.update({'root_uuid': None}) + count = query.delete() + if count != 1: + raise exception.DeployableNotFound(uuid=uuid) + + def deployable_get_by_filters(self, context, + filters, sort_key='created_at', + sort_dir='desc', limit=None, + marker=None, join_columns=None): + """Return list of deployables matching all filters sorted by + the sort_key. See deployable_get_by_filters_sort for + more information. + """ + return self.deployable_get_by_filters_sort(context, filters, + limit=limit, marker=marker, + join_columns=join_columns, + sort_keys=[sort_key], + sort_dirs=[sort_dir]) + + def _exact_deployable_filter(self, query, filters, legal_keys): + """Applies exact match filtering to a deployable query. + Returns the updated query. Modifies filters argument to remove + filters consumed. + :param query: query to apply filters to + :param filters: dictionary of filters; values that are lists, + tuples, sets, or frozensets cause an 'IN' test to + be performed, while exact matching ('==' operator) + is used for other values + :param legal_keys: list of keys to apply exact filtering to + """ + + filter_dict = {} + model = models.Deployable + # Walk through all the keys + for key in legal_keys: + # Skip ones we're not filtering on + if key not in filters: + continue + + # OK, filtering on this key; what value do we search for? + value = filters.pop(key) + + if isinstance(value, (list, tuple, set, frozenset)): + if not value: + return None + # Looking for values in a list; apply to query directly + column_attr = getattr(model, key) + query = query.filter(column_attr.in_(value)) + else: + filter_dict[key] = value + # Apply simple exact matches + if filter_dict: + query = query.filter(*[getattr(models.Deployable, k) == v + for k, v in filter_dict.items()]) + return query + + def deployable_get_by_filters_sort(self, context, filters, limit=None, + marker=None, join_columns=None, + sort_keys=None, sort_dirs=None): + """Return deployables that match all filters sorted by the given + keys. Deleted deployables will be returned by default, unless + there's a filter that says otherwise. + """ + + if limit == 0: + return [] + + sort_keys, sort_dirs = self.process_sort_params(sort_keys, + sort_dirs, + default_dir='desc') + query_prefix = model_query(context, models.Deployable) + filters = copy.deepcopy(filters) + + exact_match_filter_names = ['uuid', 'name', + 'parent_uuid', 'root_uuid', + 'pcie_address', 'host', + 'board', 'vendor', 'version', + 'type', 'assignable', 'instance_uuid', + 'availability', 'accelerator_id'] + + # Filter the query + query_prefix = self._exact_deployable_filter(query_prefix, + filters, + exact_match_filter_names) + if query_prefix is None: + return [] + deployables = query_prefix.all() + return deployables + + def attribute_create(self, context, key, value): + update_fields = {'key': key, 'value': value} + update_fields['uuid'] = uuidutils.generate_uuid() + + attribute = models.Attribute() + attribute.update(update_fields) + + with _session_for_write() as session: + try: + session.add(attribute) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.AttributeAlreadyExists( + uuid=update_fields['uuid']) + return attribute + + def attribute_get(self, context, uuid): + query = model_query( + context, + models.Attribute).filter_by(uuid=uuid) + try: + return query.one() + except NoResultFound: + raise exception.AttributeNotFound(uuid=uuid) + + def attribute_get_by_deployable_uuid(self, context, deployable_uuid): + query = model_query( + context, + models.Attribute).filter_by(deployable_uuid=deployable_uuid) + try: + return query.all() + except NoResultFound: + raise exception.AttributeNotFound(uuid=uuid) + + def attribute_update(self, context, uuid, key, value): + return self._do_update_attribute(context, uuid, key, value) + + @oslo_db_api.retry_on_deadlock + def _do_update_attribute(self, context, uuid, key, value): + update_fields = {'key': key, 'value': value} + with _session_for_write(): + query = model_query(context, models.Attribute) + query = add_identity_filter(query, uuid) + try: + ref = query.with_lockmode('update').one() + except NoResultFound: + raise exception.AttributeNotFound(uuid=uuid) + + ref.update(update_fields) + return ref + + def attribute_delete(self, context, uuid): + with _session_for_write(): + query = model_query(context, models.Attribute) + query = add_identity_filter(query, uuid) + count = query.delete() + if count != 1: + raise exception.AttributeNotFound(uuid=uuid) + + def process_sort_params(self, sort_keys, sort_dirs, + default_keys=['created_at', 'id'], + default_dir='asc'): + + # Determine direction to use for when adding default keys + if sort_dirs and len(sort_dirs) != 0: + default_dir_value = sort_dirs[0] + else: + default_dir_value = default_dir + + # Create list of keys (do not modify the input list) + if sort_keys: + result_keys = list(sort_keys) + else: + result_keys = [] + + # If a list of directions is not provided, + # use the default sort direction for all provided keys + if sort_dirs: + result_dirs = [] + # Verify sort direction + for sort_dir in sort_dirs: + if sort_dir not in ('asc', 'desc'): + msg = _("Unknown sort direction, must be 'desc' or 'asc'") + raise exception.InvalidInput(reason=msg) + result_dirs.append(sort_dir) + else: + result_dirs = [default_dir_value for _sort_key in result_keys] + + # Ensure that the key and direction length match + while len(result_dirs) < len(result_keys): + result_dirs.append(default_dir_value) + # Unless more direction are specified, which is an error + if len(result_dirs) > len(result_keys): + msg = _("Sort direction size exceeds sort key size") + raise exception.InvalidInput(reason=msg) + + # Ensure defaults are included + for key in default_keys: + if key not in result_keys: + result_keys.append(key) + result_dirs.append(default_dir_value) + + return result_keys, result_dirs + |