From 92fd2dbfb672d7b2b1cdfd5dd5cf89f7716b3e12 Mon Sep 17 00:00:00 2001 From: asteroide Date: Tue, 1 Sep 2015 16:03:26 +0200 Subject: Update Keystone code from official Github repository with branch Master on 09/01/2015. Change-Id: I0ff6099e6e2580f87f502002a998bbfe12673498 --- .../keystone/contrib/federation/backends/sql.py | 71 +++++++++++++++++++--- 1 file changed, 61 insertions(+), 10 deletions(-) (limited to 'keystone-moon/keystone/contrib/federation/backends') diff --git a/keystone-moon/keystone/contrib/federation/backends/sql.py b/keystone-moon/keystone/contrib/federation/backends/sql.py index f2c124d0..ed07c08f 100644 --- a/keystone-moon/keystone/contrib/federation/backends/sql.py +++ b/keystone-moon/keystone/contrib/federation/backends/sql.py @@ -17,6 +17,7 @@ from oslo_serialization import jsonutils from keystone.common import sql from keystone.contrib.federation import core from keystone import exception +from sqlalchemy import orm class FederationProtocolModel(sql.ModelBase, sql.DictBase): @@ -44,13 +45,53 @@ class FederationProtocolModel(sql.ModelBase, sql.DictBase): class IdentityProviderModel(sql.ModelBase, sql.DictBase): __tablename__ = 'identity_provider' - attributes = ['id', 'remote_id', 'enabled', 'description'] - mutable_attributes = frozenset(['description', 'enabled', 'remote_id']) + attributes = ['id', 'enabled', 'description', 'remote_ids'] + mutable_attributes = frozenset(['description', 'enabled', 'remote_ids']) id = sql.Column(sql.String(64), primary_key=True) - remote_id = sql.Column(sql.String(256), nullable=True) enabled = sql.Column(sql.Boolean, nullable=False) description = sql.Column(sql.Text(), nullable=True) + remote_ids = orm.relationship('IdPRemoteIdsModel', + order_by='IdPRemoteIdsModel.remote_id', + cascade='all, delete-orphan') + + @classmethod + def from_dict(cls, dictionary): + new_dictionary = dictionary.copy() + remote_ids_list = new_dictionary.pop('remote_ids', None) + if not remote_ids_list: + remote_ids_list = [] + identity_provider = cls(**new_dictionary) + remote_ids = [] + # NOTE(fmarco76): the remote_ids_list contains only remote ids + # associated with the IdP because of the "relationship" established in + # sqlalchemy and corresponding to the FK in the idp_remote_ids table + for remote in remote_ids_list: + remote_ids.append(IdPRemoteIdsModel(remote_id=remote)) + identity_provider.remote_ids = remote_ids + return identity_provider + + def to_dict(self): + """Return a dictionary with model's attributes.""" + d = dict() + for attr in self.__class__.attributes: + d[attr] = getattr(self, attr) + d['remote_ids'] = [] + for remote in self.remote_ids: + d['remote_ids'].append(remote.remote_id) + return d + + +class IdPRemoteIdsModel(sql.ModelBase, sql.DictBase): + __tablename__ = 'idp_remote_ids' + attributes = ['idp_id', 'remote_id'] + mutable_attributes = frozenset(['idp_id', 'remote_id']) + + idp_id = sql.Column(sql.String(64), + sql.ForeignKey('identity_provider.id', + ondelete='CASCADE')) + remote_id = sql.Column(sql.String(255), + primary_key=True) @classmethod def from_dict(cls, dictionary): @@ -75,6 +116,7 @@ class MappingModel(sql.ModelBase, sql.DictBase): @classmethod def from_dict(cls, dictionary): new_dictionary = dictionary.copy() + new_dictionary['rules'] = jsonutils.dumps(new_dictionary['rules']) return cls(**new_dictionary) def to_dict(self): @@ -82,20 +124,23 @@ class MappingModel(sql.ModelBase, sql.DictBase): d = dict() for attr in self.__class__.attributes: d[attr] = getattr(self, attr) + d['rules'] = jsonutils.loads(d['rules']) return d class ServiceProviderModel(sql.ModelBase, sql.DictBase): __tablename__ = 'service_provider' - attributes = ['auth_url', 'id', 'enabled', 'description', 'sp_url'] + attributes = ['auth_url', 'id', 'enabled', 'description', + 'relay_state_prefix', 'sp_url'] mutable_attributes = frozenset(['auth_url', 'description', 'enabled', - 'sp_url']) + 'relay_state_prefix', 'sp_url']) id = sql.Column(sql.String(64), primary_key=True) enabled = sql.Column(sql.Boolean, nullable=False) description = sql.Column(sql.Text(), nullable=True) auth_url = sql.Column(sql.String(256), nullable=False) sp_url = sql.Column(sql.String(256), nullable=False) + relay_state_prefix = sql.Column(sql.String(256), nullable=False) @classmethod def from_dict(cls, dictionary): @@ -123,6 +168,7 @@ class Federation(core.Driver): def delete_idp(self, idp_id): with sql.transaction() as session: + self._delete_assigned_protocols(session, idp_id) idp_ref = self._get_idp(session, idp_id) session.delete(idp_ref) @@ -133,7 +179,7 @@ class Federation(core.Driver): return idp_ref def _get_idp_from_remote_id(self, session, remote_id): - q = session.query(IdentityProviderModel) + q = session.query(IdPRemoteIdsModel) q = q.filter_by(remote_id=remote_id) try: return q.one() @@ -153,8 +199,8 @@ class Federation(core.Driver): def get_idp_from_remote_id(self, remote_id): with sql.transaction() as session: - idp_ref = self._get_idp_from_remote_id(session, remote_id) - return idp_ref.to_dict() + ref = self._get_idp_from_remote_id(session, remote_id) + return ref.to_dict() def update_idp(self, idp_id, idp): with sql.transaction() as session: @@ -214,6 +260,11 @@ class Federation(core.Driver): key_ref = self._get_protocol(session, idp_id, protocol_id) session.delete(key_ref) + def _delete_assigned_protocols(self, session, idp_id): + query = session.query(FederationProtocolModel) + query = query.filter_by(idp_id=idp_id) + query.delete() + # Mapping CRUD def _get_mapping(self, session, mapping_id): mapping_ref = session.query(MappingModel).get(mapping_id) @@ -225,7 +276,7 @@ class Federation(core.Driver): def create_mapping(self, mapping_id, mapping): ref = {} ref['id'] = mapping_id - ref['rules'] = jsonutils.dumps(mapping.get('rules')) + ref['rules'] = mapping.get('rules') with sql.transaction() as session: mapping_ref = MappingModel.from_dict(ref) session.add(mapping_ref) @@ -250,7 +301,7 @@ class Federation(core.Driver): def update_mapping(self, mapping_id, mapping): ref = {} ref['id'] = mapping_id - ref['rules'] = jsonutils.dumps(mapping.get('rules')) + ref['rules'] = mapping.get('rules') with sql.transaction() as session: mapping_ref = self._get_mapping(session, mapping_id) old_mapping = mapping_ref.to_dict() -- cgit 1.2.3-korg