From 2e7b4f2027a1147ca28301e4f88adf8274b39a1f Mon Sep 17 00:00:00 2001 From: DUVAL Thomas Date: Thu, 9 Jun 2016 09:11:50 +0200 Subject: Update Keystone core to Mitaka. Change-Id: Ia10d6add16f4a9d25d1f42d420661c46332e69db --- keystone-moon/keystone/catalog/backends/sql.py | 429 ++++++++++++++++++------- 1 file changed, 309 insertions(+), 120 deletions(-) (limited to 'keystone-moon/keystone/catalog/backends/sql.py') diff --git a/keystone-moon/keystone/catalog/backends/sql.py b/keystone-moon/keystone/catalog/backends/sql.py index fe69db58..bd92f107 100644 --- a/keystone-moon/keystone/catalog/backends/sql.py +++ b/keystone-moon/keystone/catalog/backends/sql.py @@ -21,8 +21,10 @@ from sqlalchemy.sql import true from keystone import catalog from keystone.catalog import core +from keystone.common import driver_hints from keystone.common import sql from keystone import exception +from keystone.i18n import _ CONF = cfg.CONF @@ -43,13 +45,6 @@ class Region(sql.ModelBase, sql.DictBase): # "left" and "right" and provide support for a nested set # model. parent_region_id = sql.Column(sql.String(255), nullable=True) - - # TODO(jaypipes): I think it's absolutely stupid that every single model - # is required to have an "extra" column because of the - # DictBase in the keystone.common.sql.core module. Forcing - # tables to have pointless columns in the database is just - # bad. Remove all of this extra JSON blob stuff. - # See: https://bugs.launchpad.net/keystone/+bug/1265071 extra = sql.Column(sql.JsonBlob()) endpoints = sqlalchemy.orm.relationship("Endpoint", backref="region") @@ -89,10 +84,10 @@ class Endpoint(sql.ModelBase, sql.DictBase): class Catalog(catalog.CatalogDriverV8): # Regions def list_regions(self, hints): - session = sql.get_session() - regions = session.query(Region) - regions = sql.filter_limit_query(Region, regions, hints) - return [s.to_dict() for s in list(regions)] + with sql.session_for_read() as session: + regions = session.query(Region) + regions = sql.filter_limit_query(Region, regions, hints) + return [s.to_dict() for s in list(regions)] def _get_region(self, session, region_id): ref = session.query(Region).get(region_id) @@ -141,12 +136,11 @@ class Catalog(catalog.CatalogDriverV8): return False def get_region(self, region_id): - session = sql.get_session() - return self._get_region(session, region_id).to_dict() + with sql.session_for_read() as session: + return self._get_region(session, region_id).to_dict() def delete_region(self, region_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = self._get_region(session, region_id) if self._has_endpoints(session, ref, ref): raise exception.RegionDeletionError(region_id=region_id) @@ -155,16 +149,14 @@ class Catalog(catalog.CatalogDriverV8): @sql.handle_conflicts(conflict_type='region') def create_region(self, region_ref): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: self._check_parent_region(session, region_ref) region = Region.from_dict(region_ref) session.add(region) - return region.to_dict() + return region.to_dict() def update_region(self, region_id, region_ref): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: self._check_parent_region(session, region_ref) ref = self._get_region(session, region_id) old_dict = ref.to_dict() @@ -174,15 +166,15 @@ class Catalog(catalog.CatalogDriverV8): for attr in Region.attributes: if attr != 'id': setattr(ref, attr, getattr(new_region, attr)) - return ref.to_dict() + return ref.to_dict() # Services - @sql.truncated + @driver_hints.truncated def list_services(self, hints): - session = sql.get_session() - services = session.query(Service) - services = sql.filter_limit_query(Service, services, hints) - return [s.to_dict() for s in list(services)] + with sql.session_for_read() as session: + services = session.query(Service) + services = sql.filter_limit_query(Service, services, hints) + return [s.to_dict() for s in list(services)] def _get_service(self, session, service_id): ref = session.query(Service).get(service_id) @@ -191,26 +183,23 @@ class Catalog(catalog.CatalogDriverV8): return ref def get_service(self, service_id): - session = sql.get_session() - return self._get_service(session, service_id).to_dict() + with sql.session_for_read() as session: + return self._get_service(session, service_id).to_dict() def delete_service(self, service_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = self._get_service(session, service_id) session.query(Endpoint).filter_by(service_id=service_id).delete() session.delete(ref) def create_service(self, service_id, service_ref): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: service = Service.from_dict(service_ref) session.add(service) - return service.to_dict() + return service.to_dict() def update_service(self, service_id, service_ref): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = self._get_service(session, service_id) old_dict = ref.to_dict() old_dict.update(service_ref) @@ -219,20 +208,17 @@ class Catalog(catalog.CatalogDriverV8): if attr != 'id': setattr(ref, attr, getattr(new_service, attr)) ref.extra = new_service.extra - return ref.to_dict() + return ref.to_dict() # Endpoints def create_endpoint(self, endpoint_id, endpoint_ref): - session = sql.get_session() new_endpoint = Endpoint.from_dict(endpoint_ref) - - with session.begin(): + with sql.session_for_write() as session: session.add(new_endpoint) return new_endpoint.to_dict() def delete_endpoint(self, endpoint_id): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = self._get_endpoint(session, endpoint_id) session.delete(ref) @@ -243,20 +229,18 @@ class Catalog(catalog.CatalogDriverV8): raise exception.EndpointNotFound(endpoint_id=endpoint_id) def get_endpoint(self, endpoint_id): - session = sql.get_session() - return self._get_endpoint(session, endpoint_id).to_dict() + with sql.session_for_read() as session: + return self._get_endpoint(session, endpoint_id).to_dict() - @sql.truncated + @driver_hints.truncated def list_endpoints(self, hints): - session = sql.get_session() - endpoints = session.query(Endpoint) - endpoints = sql.filter_limit_query(Endpoint, endpoints, hints) - return [e.to_dict() for e in list(endpoints)] + with sql.session_for_read() as session: + endpoints = session.query(Endpoint) + endpoints = sql.filter_limit_query(Endpoint, endpoints, hints) + return [e.to_dict() for e in list(endpoints)] def update_endpoint(self, endpoint_id, endpoint_ref): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_endpoint(session, endpoint_id) old_dict = ref.to_dict() old_dict.update(endpoint_ref) @@ -265,7 +249,7 @@ class Catalog(catalog.CatalogDriverV8): if attr != 'id': setattr(ref, attr, getattr(new_endpoint, attr)) ref.extra = new_endpoint.extra - return ref.to_dict() + return ref.to_dict() def get_catalog(self, user_id, tenant_id): """Retrieve and format the V2 service catalog. @@ -287,44 +271,47 @@ class Catalog(catalog.CatalogDriverV8): substitutions.update({'user_id': user_id}) silent_keyerror_failures = [] if tenant_id: - substitutions.update({'tenant_id': tenant_id}) + substitutions.update({ + 'tenant_id': tenant_id, + 'project_id': tenant_id + }) else: - silent_keyerror_failures = ['tenant_id'] - - session = sql.get_session() - endpoints = (session.query(Endpoint). - options(sql.joinedload(Endpoint.service)). - filter(Endpoint.enabled == true()).all()) - - catalog = {} - - for endpoint in endpoints: - if not endpoint.service['enabled']: - continue - try: - formatted_url = core.format_url( - endpoint['url'], substitutions, - silent_keyerror_failures=silent_keyerror_failures) - if formatted_url is not None: - url = formatted_url - else: + silent_keyerror_failures = ['tenant_id', 'project_id', ] + + with sql.session_for_read() as session: + endpoints = (session.query(Endpoint). + options(sql.joinedload(Endpoint.service)). + filter(Endpoint.enabled == true()).all()) + + catalog = {} + + for endpoint in endpoints: + if not endpoint.service['enabled']: continue - except exception.MalformedEndpoint: - continue # this failure is already logged in format_url() - - region = endpoint['region_id'] - service_type = endpoint.service['type'] - default_service = { - 'id': endpoint['id'], - 'name': endpoint.service.extra.get('name', ''), - 'publicURL': '' - } - catalog.setdefault(region, {}) - catalog[region].setdefault(service_type, default_service) - interface_url = '%sURL' % endpoint['interface'] - catalog[region][service_type][interface_url] = url - - return catalog + try: + formatted_url = core.format_url( + endpoint['url'], substitutions, + silent_keyerror_failures=silent_keyerror_failures) + if formatted_url is not None: + url = formatted_url + else: + continue + except exception.MalformedEndpoint: + continue # this failure is already logged in format_url() + + region = endpoint['region_id'] + service_type = endpoint.service['type'] + default_service = { + 'id': endpoint['id'], + 'name': endpoint.service.extra.get('name', ''), + 'publicURL': '' + } + catalog.setdefault(region, {}) + catalog[region].setdefault(service_type, default_service) + interface_url = '%sURL' % endpoint['interface'] + catalog[region][service_type][interface_url] = url + + return catalog def get_v3_catalog(self, user_id, tenant_id): """Retrieve and format the current V3 service catalog. @@ -344,40 +331,242 @@ class Catalog(catalog.CatalogDriverV8): d.update({'user_id': user_id}) silent_keyerror_failures = [] if tenant_id: - d.update({'tenant_id': tenant_id}) + d.update({ + 'tenant_id': tenant_id, + 'project_id': tenant_id, + }) else: - silent_keyerror_failures = ['tenant_id'] - - session = sql.get_session() - services = (session.query(Service).filter(Service.enabled == true()). - options(sql.joinedload(Service.endpoints)). - all()) - - def make_v3_endpoints(endpoints): - for endpoint in (ep.to_dict() for ep in endpoints if ep.enabled): - del endpoint['service_id'] - del endpoint['legacy_endpoint_id'] - del endpoint['enabled'] - endpoint['region'] = endpoint['region_id'] - try: - formatted_url = core.format_url( - endpoint['url'], d, - silent_keyerror_failures=silent_keyerror_failures) - if formatted_url: - endpoint['url'] = formatted_url - else: + silent_keyerror_failures = ['tenant_id', 'project_id', ] + + with sql.session_for_read() as session: + services = (session.query(Service).filter( + Service.enabled == true()).options( + sql.joinedload(Service.endpoints)).all()) + + def make_v3_endpoints(endpoints): + for endpoint in (ep.to_dict() + for ep in endpoints if ep.enabled): + del endpoint['service_id'] + del endpoint['legacy_endpoint_id'] + del endpoint['enabled'] + endpoint['region'] = endpoint['region_id'] + try: + formatted_url = core.format_url( + endpoint['url'], d, + silent_keyerror_failures=silent_keyerror_failures) + if formatted_url: + endpoint['url'] = formatted_url + else: + continue + except exception.MalformedEndpoint: + # this failure is already logged in format_url() continue - except exception.MalformedEndpoint: - continue # this failure is already logged in format_url() - yield endpoint + yield endpoint + + # TODO(davechen): If there is service with no endpoints, we should + # skip the service instead of keeping it in the catalog, + # see bug #1436704. + def make_v3_service(svc): + eps = list(make_v3_endpoints(svc.endpoints)) + service = {'endpoints': eps, 'id': svc.id, 'type': svc.type} + service['name'] = svc.extra.get('name', '') + return service + + return [make_v3_service(svc) for svc in services] + + @sql.handle_conflicts(conflict_type='project_endpoint') + def add_endpoint_to_project(self, endpoint_id, project_id): + with sql.session_for_write() as session: + endpoint_filter_ref = ProjectEndpoint(endpoint_id=endpoint_id, + project_id=project_id) + session.add(endpoint_filter_ref) + + def _get_project_endpoint_ref(self, session, endpoint_id, project_id): + endpoint_filter_ref = session.query(ProjectEndpoint).get( + (endpoint_id, project_id)) + if endpoint_filter_ref is None: + msg = _('Endpoint %(endpoint_id)s not found in project ' + '%(project_id)s') % {'endpoint_id': endpoint_id, + 'project_id': project_id} + raise exception.NotFound(msg) + return endpoint_filter_ref + + def check_endpoint_in_project(self, endpoint_id, project_id): + with sql.session_for_read() as session: + self._get_project_endpoint_ref(session, endpoint_id, project_id) + + def remove_endpoint_from_project(self, endpoint_id, project_id): + with sql.session_for_write() as session: + endpoint_filter_ref = self._get_project_endpoint_ref( + session, endpoint_id, project_id) + session.delete(endpoint_filter_ref) + + def list_endpoints_for_project(self, project_id): + with sql.session_for_read() as session: + query = session.query(ProjectEndpoint) + query = query.filter_by(project_id=project_id) + endpoint_filter_refs = query.all() + return [ref.to_dict() for ref in endpoint_filter_refs] + + def list_projects_for_endpoint(self, endpoint_id): + with sql.session_for_read() as session: + query = session.query(ProjectEndpoint) + query = query.filter_by(endpoint_id=endpoint_id) + endpoint_filter_refs = query.all() + return [ref.to_dict() for ref in endpoint_filter_refs] + + def delete_association_by_endpoint(self, endpoint_id): + with sql.session_for_write() as session: + query = session.query(ProjectEndpoint) + query = query.filter_by(endpoint_id=endpoint_id) + query.delete(synchronize_session=False) + + def delete_association_by_project(self, project_id): + with sql.session_for_write() as session: + query = session.query(ProjectEndpoint) + query = query.filter_by(project_id=project_id) + query.delete(synchronize_session=False) + + def create_endpoint_group(self, endpoint_group_id, endpoint_group): + with sql.session_for_write() as session: + endpoint_group_ref = EndpointGroup.from_dict(endpoint_group) + session.add(endpoint_group_ref) + return endpoint_group_ref.to_dict() + + def _get_endpoint_group(self, session, endpoint_group_id): + endpoint_group_ref = session.query(EndpointGroup).get( + endpoint_group_id) + if endpoint_group_ref is None: + raise exception.EndpointGroupNotFound( + endpoint_group_id=endpoint_group_id) + return endpoint_group_ref + + def get_endpoint_group(self, endpoint_group_id): + with sql.session_for_read() as session: + endpoint_group_ref = self._get_endpoint_group(session, + endpoint_group_id) + return endpoint_group_ref.to_dict() + + def update_endpoint_group(self, endpoint_group_id, endpoint_group): + with sql.session_for_write() as session: + endpoint_group_ref = self._get_endpoint_group(session, + endpoint_group_id) + old_endpoint_group = endpoint_group_ref.to_dict() + old_endpoint_group.update(endpoint_group) + new_endpoint_group = EndpointGroup.from_dict(old_endpoint_group) + for attr in EndpointGroup.mutable_attributes: + setattr(endpoint_group_ref, attr, + getattr(new_endpoint_group, attr)) + return endpoint_group_ref.to_dict() + + def delete_endpoint_group(self, endpoint_group_id): + with sql.session_for_write() as session: + endpoint_group_ref = self._get_endpoint_group(session, + endpoint_group_id) + self._delete_endpoint_group_association_by_endpoint_group( + session, endpoint_group_id) + session.delete(endpoint_group_ref) + + def get_endpoint_group_in_project(self, endpoint_group_id, project_id): + with sql.session_for_read() as session: + ref = self._get_endpoint_group_in_project(session, + endpoint_group_id, + project_id) + return ref.to_dict() + + @sql.handle_conflicts(conflict_type='project_endpoint_group') + def add_endpoint_group_to_project(self, endpoint_group_id, project_id): + with sql.session_for_write() as session: + # Create a new Project Endpoint group entity + endpoint_group_project_ref = ProjectEndpointGroupMembership( + endpoint_group_id=endpoint_group_id, project_id=project_id) + session.add(endpoint_group_project_ref) + + def _get_endpoint_group_in_project(self, session, + endpoint_group_id, project_id): + endpoint_group_project_ref = session.query( + ProjectEndpointGroupMembership).get((endpoint_group_id, + project_id)) + if endpoint_group_project_ref is None: + msg = _('Endpoint Group Project Association not found') + raise exception.NotFound(msg) + else: + return endpoint_group_project_ref + + def list_endpoint_groups(self): + with sql.session_for_read() as session: + query = session.query(EndpointGroup) + endpoint_group_refs = query.all() + return [e.to_dict() for e in endpoint_group_refs] + + def list_endpoint_groups_for_project(self, project_id): + with sql.session_for_read() as session: + query = session.query(ProjectEndpointGroupMembership) + query = query.filter_by(project_id=project_id) + endpoint_group_refs = query.all() + return [ref.to_dict() for ref in endpoint_group_refs] + + def remove_endpoint_group_from_project(self, endpoint_group_id, + project_id): + with sql.session_for_write() as session: + endpoint_group_project_ref = self._get_endpoint_group_in_project( + session, endpoint_group_id, project_id) + session.delete(endpoint_group_project_ref) + + def list_projects_associated_with_endpoint_group(self, endpoint_group_id): + with sql.session_for_read() as session: + query = session.query(ProjectEndpointGroupMembership) + query = query.filter_by(endpoint_group_id=endpoint_group_id) + endpoint_group_refs = query.all() + return [ref.to_dict() for ref in endpoint_group_refs] + + def _delete_endpoint_group_association_by_endpoint_group( + self, session, endpoint_group_id): + query = session.query(ProjectEndpointGroupMembership) + query = query.filter_by(endpoint_group_id=endpoint_group_id) + query.delete() + + def delete_endpoint_group_association_by_project(self, project_id): + with sql.session_for_write() as session: + query = session.query(ProjectEndpointGroupMembership) + query = query.filter_by(project_id=project_id) + query.delete() + + +class ProjectEndpoint(sql.ModelBase, sql.ModelDictMixin): + """project-endpoint relationship table.""" + + __tablename__ = 'project_endpoint' + attributes = ['endpoint_id', 'project_id'] + endpoint_id = sql.Column(sql.String(64), + primary_key=True, + nullable=False) + project_id = sql.Column(sql.String(64), + primary_key=True, + nullable=False) + - # TODO(davechen): If there is service with no endpoints, we should skip - # the service instead of keeping it in the catalog, see bug #1436704. - def make_v3_service(svc): - eps = list(make_v3_endpoints(svc.endpoints)) - service = {'endpoints': eps, 'id': svc.id, 'type': svc.type} - service['name'] = svc.extra.get('name', '') - return service +class EndpointGroup(sql.ModelBase, sql.ModelDictMixin): + """Endpoint Groups table.""" - return [make_v3_service(svc) for svc in services] + __tablename__ = 'endpoint_group' + attributes = ['id', 'name', 'description', 'filters'] + mutable_attributes = frozenset(['name', 'description', 'filters']) + id = sql.Column(sql.String(64), primary_key=True) + name = sql.Column(sql.String(255), nullable=False) + description = sql.Column(sql.Text, nullable=True) + filters = sql.Column(sql.JsonBlob(), nullable=False) + + +class ProjectEndpointGroupMembership(sql.ModelBase, sql.ModelDictMixin): + """Project to Endpoint group relationship table.""" + + __tablename__ = 'project_endpoint_group' + attributes = ['endpoint_group_id', 'project_id'] + endpoint_group_id = sql.Column(sql.String(64), + sql.ForeignKey('endpoint_group.id'), + nullable=False) + project_id = sql.Column(sql.String(64), nullable=False) + __table_args__ = (sql.PrimaryKeyConstraint('endpoint_group_id', + 'project_id'),) -- cgit 1.2.3-korg