diff options
Diffstat (limited to 'keystone-moon/keystone/contrib/federation/backends/sql.py')
-rw-r--r-- | keystone-moon/keystone/contrib/federation/backends/sql.py | 357 |
1 files changed, 10 insertions, 347 deletions
diff --git a/keystone-moon/keystone/contrib/federation/backends/sql.py b/keystone-moon/keystone/contrib/federation/backends/sql.py index dbd17025..3c24d9c0 100644 --- a/keystone-moon/keystone/contrib/federation/backends/sql.py +++ b/keystone-moon/keystone/contrib/federation/backends/sql.py @@ -12,355 +12,18 @@ # License for the specific language governing permissions and limitations # under the License. -from oslo_serialization import jsonutils +from oslo_log import versionutils -from keystone.common import sql -from keystone.contrib.federation import core -from keystone import exception -from sqlalchemy import orm +from keystone.federation.backends import sql +_OLD = "keystone.contrib.federation.backends.sql.Federation" +_NEW = "sql" -class FederationProtocolModel(sql.ModelBase, sql.DictBase): - __tablename__ = 'federation_protocol' - attributes = ['id', 'idp_id', 'mapping_id'] - mutable_attributes = frozenset(['mapping_id']) - id = sql.Column(sql.String(64), primary_key=True) - idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id', - ondelete='CASCADE'), primary_key=True) - mapping_id = sql.Column(sql.String(64), nullable=False) +class Federation(sql.Federation): - @classmethod - def from_dict(cls, dictionary): - new_dictionary = dictionary.copy() - return cls(**new_dictionary) - - def to_dict(self): - """Return a dictionary with model's attributes.""" - d = dict() - for attr in self.__class__.attributes: - d[attr] = getattr(self, attr) - return d - - -class IdentityProviderModel(sql.ModelBase, sql.DictBase): - __tablename__ = 'identity_provider' - attributes = ['id', 'enabled', 'description', 'remote_ids'] - mutable_attributes = frozenset(['description', 'enabled', 'remote_ids']) - - id = sql.Column(sql.String(64), primary_key=True) - enabled = sql.Column(sql.Boolean, nullable=False) - description = sql.Column(sql.Text(), nullable=True) - remote_ids = orm.relationship('IdPRemoteIdsModel', - order_by='IdPRemoteIdsModel.remote_id', - cascade='all, delete-orphan') - - @classmethod - def from_dict(cls, dictionary): - new_dictionary = dictionary.copy() - remote_ids_list = new_dictionary.pop('remote_ids', None) - if not remote_ids_list: - remote_ids_list = [] - identity_provider = cls(**new_dictionary) - remote_ids = [] - # NOTE(fmarco76): the remote_ids_list contains only remote ids - # associated with the IdP because of the "relationship" established in - # sqlalchemy and corresponding to the FK in the idp_remote_ids table - for remote in remote_ids_list: - remote_ids.append(IdPRemoteIdsModel(remote_id=remote)) - identity_provider.remote_ids = remote_ids - return identity_provider - - def to_dict(self): - """Return a dictionary with model's attributes.""" - d = dict() - for attr in self.__class__.attributes: - d[attr] = getattr(self, attr) - d['remote_ids'] = [] - for remote in self.remote_ids: - d['remote_ids'].append(remote.remote_id) - return d - - -class IdPRemoteIdsModel(sql.ModelBase, sql.DictBase): - __tablename__ = 'idp_remote_ids' - attributes = ['idp_id', 'remote_id'] - mutable_attributes = frozenset(['idp_id', 'remote_id']) - - idp_id = sql.Column(sql.String(64), - sql.ForeignKey('identity_provider.id', - ondelete='CASCADE')) - remote_id = sql.Column(sql.String(255), - primary_key=True) - - @classmethod - def from_dict(cls, dictionary): - new_dictionary = dictionary.copy() - return cls(**new_dictionary) - - def to_dict(self): - """Return a dictionary with model's attributes.""" - d = dict() - for attr in self.__class__.attributes: - d[attr] = getattr(self, attr) - return d - - -class MappingModel(sql.ModelBase, sql.DictBase): - __tablename__ = 'mapping' - attributes = ['id', 'rules'] - - id = sql.Column(sql.String(64), primary_key=True) - rules = sql.Column(sql.JsonBlob(), nullable=False) - - @classmethod - def from_dict(cls, dictionary): - new_dictionary = dictionary.copy() - new_dictionary['rules'] = jsonutils.dumps(new_dictionary['rules']) - return cls(**new_dictionary) - - def to_dict(self): - """Return a dictionary with model's attributes.""" - d = dict() - for attr in self.__class__.attributes: - d[attr] = getattr(self, attr) - d['rules'] = jsonutils.loads(d['rules']) - return d - - -class ServiceProviderModel(sql.ModelBase, sql.DictBase): - __tablename__ = 'service_provider' - attributes = ['auth_url', 'id', 'enabled', 'description', - 'relay_state_prefix', 'sp_url'] - mutable_attributes = frozenset(['auth_url', 'description', 'enabled', - 'relay_state_prefix', 'sp_url']) - - id = sql.Column(sql.String(64), primary_key=True) - enabled = sql.Column(sql.Boolean, nullable=False) - description = sql.Column(sql.Text(), nullable=True) - auth_url = sql.Column(sql.String(256), nullable=False) - sp_url = sql.Column(sql.String(256), nullable=False) - relay_state_prefix = sql.Column(sql.String(256), nullable=False) - - @classmethod - def from_dict(cls, dictionary): - new_dictionary = dictionary.copy() - return cls(**new_dictionary) - - def to_dict(self): - """Return a dictionary with model's attributes.""" - d = dict() - for attr in self.__class__.attributes: - d[attr] = getattr(self, attr) - return d - - -class Federation(core.FederationDriverV8): - - # Identity Provider CRUD - @sql.handle_conflicts(conflict_type='identity_provider') - def create_idp(self, idp_id, idp): - idp['id'] = idp_id - with sql.transaction() as session: - idp_ref = IdentityProviderModel.from_dict(idp) - session.add(idp_ref) - return idp_ref.to_dict() - - def delete_idp(self, idp_id): - with sql.transaction() as session: - self._delete_assigned_protocols(session, idp_id) - idp_ref = self._get_idp(session, idp_id) - session.delete(idp_ref) - - def _get_idp(self, session, idp_id): - idp_ref = session.query(IdentityProviderModel).get(idp_id) - if not idp_ref: - raise exception.IdentityProviderNotFound(idp_id=idp_id) - return idp_ref - - def _get_idp_from_remote_id(self, session, remote_id): - q = session.query(IdPRemoteIdsModel) - q = q.filter_by(remote_id=remote_id) - try: - return q.one() - except sql.NotFound: - raise exception.IdentityProviderNotFound(idp_id=remote_id) - - def list_idps(self): - with sql.transaction() as session: - idps = session.query(IdentityProviderModel) - idps_list = [idp.to_dict() for idp in idps] - return idps_list - - def get_idp(self, idp_id): - with sql.transaction() as session: - idp_ref = self._get_idp(session, idp_id) - return idp_ref.to_dict() - - def get_idp_from_remote_id(self, remote_id): - with sql.transaction() as session: - ref = self._get_idp_from_remote_id(session, remote_id) - return ref.to_dict() - - def update_idp(self, idp_id, idp): - with sql.transaction() as session: - idp_ref = self._get_idp(session, idp_id) - old_idp = idp_ref.to_dict() - old_idp.update(idp) - new_idp = IdentityProviderModel.from_dict(old_idp) - for attr in IdentityProviderModel.mutable_attributes: - setattr(idp_ref, attr, getattr(new_idp, attr)) - return idp_ref.to_dict() - - # Protocol CRUD - def _get_protocol(self, session, idp_id, protocol_id): - q = session.query(FederationProtocolModel) - q = q.filter_by(id=protocol_id, idp_id=idp_id) - try: - return q.one() - except sql.NotFound: - kwargs = {'protocol_id': protocol_id, - 'idp_id': idp_id} - raise exception.FederatedProtocolNotFound(**kwargs) - - @sql.handle_conflicts(conflict_type='federation_protocol') - def create_protocol(self, idp_id, protocol_id, protocol): - protocol['id'] = protocol_id - protocol['idp_id'] = idp_id - with sql.transaction() as session: - self._get_idp(session, idp_id) - protocol_ref = FederationProtocolModel.from_dict(protocol) - session.add(protocol_ref) - return protocol_ref.to_dict() - - def update_protocol(self, idp_id, protocol_id, protocol): - with sql.transaction() as session: - proto_ref = self._get_protocol(session, idp_id, protocol_id) - old_proto = proto_ref.to_dict() - old_proto.update(protocol) - new_proto = FederationProtocolModel.from_dict(old_proto) - for attr in FederationProtocolModel.mutable_attributes: - setattr(proto_ref, attr, getattr(new_proto, attr)) - return proto_ref.to_dict() - - def get_protocol(self, idp_id, protocol_id): - with sql.transaction() as session: - protocol_ref = self._get_protocol(session, idp_id, protocol_id) - return protocol_ref.to_dict() - - def list_protocols(self, idp_id): - with sql.transaction() as session: - q = session.query(FederationProtocolModel) - q = q.filter_by(idp_id=idp_id) - protocols = [protocol.to_dict() for protocol in q] - return protocols - - def delete_protocol(self, idp_id, protocol_id): - with sql.transaction() as session: - key_ref = self._get_protocol(session, idp_id, protocol_id) - session.delete(key_ref) - - def _delete_assigned_protocols(self, session, idp_id): - query = session.query(FederationProtocolModel) - query = query.filter_by(idp_id=idp_id) - query.delete() - - # Mapping CRUD - def _get_mapping(self, session, mapping_id): - mapping_ref = session.query(MappingModel).get(mapping_id) - if not mapping_ref: - raise exception.MappingNotFound(mapping_id=mapping_id) - return mapping_ref - - @sql.handle_conflicts(conflict_type='mapping') - def create_mapping(self, mapping_id, mapping): - ref = {} - ref['id'] = mapping_id - ref['rules'] = mapping.get('rules') - with sql.transaction() as session: - mapping_ref = MappingModel.from_dict(ref) - session.add(mapping_ref) - return mapping_ref.to_dict() - - def delete_mapping(self, mapping_id): - with sql.transaction() as session: - mapping_ref = self._get_mapping(session, mapping_id) - session.delete(mapping_ref) - - def list_mappings(self): - with sql.transaction() as session: - mappings = session.query(MappingModel) - return [x.to_dict() for x in mappings] - - def get_mapping(self, mapping_id): - with sql.transaction() as session: - mapping_ref = self._get_mapping(session, mapping_id) - return mapping_ref.to_dict() - - @sql.handle_conflicts(conflict_type='mapping') - def update_mapping(self, mapping_id, mapping): - ref = {} - ref['id'] = mapping_id - ref['rules'] = mapping.get('rules') - with sql.transaction() as session: - mapping_ref = self._get_mapping(session, mapping_id) - old_mapping = mapping_ref.to_dict() - old_mapping.update(ref) - new_mapping = MappingModel.from_dict(old_mapping) - for attr in MappingModel.attributes: - setattr(mapping_ref, attr, getattr(new_mapping, attr)) - return mapping_ref.to_dict() - - def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id): - with sql.transaction() as session: - protocol_ref = self._get_protocol(session, idp_id, protocol_id) - mapping_id = protocol_ref.mapping_id - mapping_ref = self._get_mapping(session, mapping_id) - return mapping_ref.to_dict() - - # Service Provider CRUD - @sql.handle_conflicts(conflict_type='service_provider') - def create_sp(self, sp_id, sp): - sp['id'] = sp_id - with sql.transaction() as session: - sp_ref = ServiceProviderModel.from_dict(sp) - session.add(sp_ref) - return sp_ref.to_dict() - - def delete_sp(self, sp_id): - with sql.transaction() as session: - sp_ref = self._get_sp(session, sp_id) - session.delete(sp_ref) - - def _get_sp(self, session, sp_id): - sp_ref = session.query(ServiceProviderModel).get(sp_id) - if not sp_ref: - raise exception.ServiceProviderNotFound(sp_id=sp_id) - return sp_ref - - def list_sps(self): - with sql.transaction() as session: - sps = session.query(ServiceProviderModel) - sps_list = [sp.to_dict() for sp in sps] - return sps_list - - def get_sp(self, sp_id): - with sql.transaction() as session: - sp_ref = self._get_sp(session, sp_id) - return sp_ref.to_dict() - - def update_sp(self, sp_id, sp): - with sql.transaction() as session: - sp_ref = self._get_sp(session, sp_id) - old_sp = sp_ref.to_dict() - old_sp.update(sp) - new_sp = ServiceProviderModel.from_dict(old_sp) - for attr in ServiceProviderModel.mutable_attributes: - setattr(sp_ref, attr, getattr(new_sp, attr)) - return sp_ref.to_dict() - - def get_enabled_service_providers(self): - with sql.transaction() as session: - service_providers = session.query(ServiceProviderModel) - service_providers = service_providers.filter_by(enabled=True) - return service_providers + @versionutils.deprecated(versionutils.deprecated.MITAKA, + in_favor_of=_NEW, + what=_OLD) + def __init__(self, *args, **kwargs): + super(Federation, self).__init__(*args, **kwargs) |