summaryrefslogtreecommitdiffstats
path: root/keystone-moon/keystone/contrib/federation/backends
diff options
context:
space:
mode:
Diffstat (limited to 'keystone-moon/keystone/contrib/federation/backends')
-rw-r--r--keystone-moon/keystone/contrib/federation/backends/__init__.py0
-rw-r--r--keystone-moon/keystone/contrib/federation/backends/sql.py315
2 files changed, 315 insertions, 0 deletions
diff --git a/keystone-moon/keystone/contrib/federation/backends/__init__.py b/keystone-moon/keystone/contrib/federation/backends/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/keystone-moon/keystone/contrib/federation/backends/__init__.py
diff --git a/keystone-moon/keystone/contrib/federation/backends/sql.py b/keystone-moon/keystone/contrib/federation/backends/sql.py
new file mode 100644
index 00000000..f2c124d0
--- /dev/null
+++ b/keystone-moon/keystone/contrib/federation/backends/sql.py
@@ -0,0 +1,315 @@
+# Copyright 2014 OpenStack Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from oslo_serialization import jsonutils
+
+from keystone.common import sql
+from keystone.contrib.federation import core
+from keystone import exception
+
+
+class FederationProtocolModel(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'federation_protocol'
+ attributes = ['id', 'idp_id', 'mapping_id']
+ mutable_attributes = frozenset(['mapping_id'])
+
+ id = sql.Column(sql.String(64), primary_key=True)
+ idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id',
+ ondelete='CASCADE'), primary_key=True)
+ mapping_id = sql.Column(sql.String(64), nullable=False)
+
+ @classmethod
+ def from_dict(cls, dictionary):
+ new_dictionary = dictionary.copy()
+ return cls(**new_dictionary)
+
+ def to_dict(self):
+ """Return a dictionary with model's attributes."""
+ d = dict()
+ for attr in self.__class__.attributes:
+ d[attr] = getattr(self, attr)
+ return d
+
+
+class IdentityProviderModel(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'identity_provider'
+ attributes = ['id', 'remote_id', 'enabled', 'description']
+ mutable_attributes = frozenset(['description', 'enabled', 'remote_id'])
+
+ id = sql.Column(sql.String(64), primary_key=True)
+ remote_id = sql.Column(sql.String(256), nullable=True)
+ enabled = sql.Column(sql.Boolean, nullable=False)
+ description = sql.Column(sql.Text(), nullable=True)
+
+ @classmethod
+ def from_dict(cls, dictionary):
+ new_dictionary = dictionary.copy()
+ return cls(**new_dictionary)
+
+ def to_dict(self):
+ """Return a dictionary with model's attributes."""
+ d = dict()
+ for attr in self.__class__.attributes:
+ d[attr] = getattr(self, attr)
+ return d
+
+
+class MappingModel(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'mapping'
+ attributes = ['id', 'rules']
+
+ id = sql.Column(sql.String(64), primary_key=True)
+ rules = sql.Column(sql.JsonBlob(), nullable=False)
+
+ @classmethod
+ def from_dict(cls, dictionary):
+ new_dictionary = dictionary.copy()
+ return cls(**new_dictionary)
+
+ def to_dict(self):
+ """Return a dictionary with model's attributes."""
+ d = dict()
+ for attr in self.__class__.attributes:
+ d[attr] = getattr(self, attr)
+ return d
+
+
+class ServiceProviderModel(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'service_provider'
+ attributes = ['auth_url', 'id', 'enabled', 'description', 'sp_url']
+ mutable_attributes = frozenset(['auth_url', 'description', 'enabled',
+ 'sp_url'])
+
+ id = sql.Column(sql.String(64), primary_key=True)
+ enabled = sql.Column(sql.Boolean, nullable=False)
+ description = sql.Column(sql.Text(), nullable=True)
+ auth_url = sql.Column(sql.String(256), nullable=False)
+ sp_url = sql.Column(sql.String(256), nullable=False)
+
+ @classmethod
+ def from_dict(cls, dictionary):
+ new_dictionary = dictionary.copy()
+ return cls(**new_dictionary)
+
+ def to_dict(self):
+ """Return a dictionary with model's attributes."""
+ d = dict()
+ for attr in self.__class__.attributes:
+ d[attr] = getattr(self, attr)
+ return d
+
+
+class Federation(core.Driver):
+
+ # Identity Provider CRUD
+ @sql.handle_conflicts(conflict_type='identity_provider')
+ def create_idp(self, idp_id, idp):
+ idp['id'] = idp_id
+ with sql.transaction() as session:
+ idp_ref = IdentityProviderModel.from_dict(idp)
+ session.add(idp_ref)
+ return idp_ref.to_dict()
+
+ def delete_idp(self, idp_id):
+ with sql.transaction() as session:
+ idp_ref = self._get_idp(session, idp_id)
+ session.delete(idp_ref)
+
+ def _get_idp(self, session, idp_id):
+ idp_ref = session.query(IdentityProviderModel).get(idp_id)
+ if not idp_ref:
+ raise exception.IdentityProviderNotFound(idp_id=idp_id)
+ return idp_ref
+
+ def _get_idp_from_remote_id(self, session, remote_id):
+ q = session.query(IdentityProviderModel)
+ q = q.filter_by(remote_id=remote_id)
+ try:
+ return q.one()
+ except sql.NotFound:
+ raise exception.IdentityProviderNotFound(idp_id=remote_id)
+
+ def list_idps(self):
+ with sql.transaction() as session:
+ idps = session.query(IdentityProviderModel)
+ idps_list = [idp.to_dict() for idp in idps]
+ return idps_list
+
+ def get_idp(self, idp_id):
+ with sql.transaction() as session:
+ idp_ref = self._get_idp(session, idp_id)
+ return idp_ref.to_dict()
+
+ def get_idp_from_remote_id(self, remote_id):
+ with sql.transaction() as session:
+ idp_ref = self._get_idp_from_remote_id(session, remote_id)
+ return idp_ref.to_dict()
+
+ def update_idp(self, idp_id, idp):
+ with sql.transaction() as session:
+ idp_ref = self._get_idp(session, idp_id)
+ old_idp = idp_ref.to_dict()
+ old_idp.update(idp)
+ new_idp = IdentityProviderModel.from_dict(old_idp)
+ for attr in IdentityProviderModel.mutable_attributes:
+ setattr(idp_ref, attr, getattr(new_idp, attr))
+ return idp_ref.to_dict()
+
+ # Protocol CRUD
+ def _get_protocol(self, session, idp_id, protocol_id):
+ q = session.query(FederationProtocolModel)
+ q = q.filter_by(id=protocol_id, idp_id=idp_id)
+ try:
+ return q.one()
+ except sql.NotFound:
+ kwargs = {'protocol_id': protocol_id,
+ 'idp_id': idp_id}
+ raise exception.FederatedProtocolNotFound(**kwargs)
+
+ @sql.handle_conflicts(conflict_type='federation_protocol')
+ def create_protocol(self, idp_id, protocol_id, protocol):
+ protocol['id'] = protocol_id
+ protocol['idp_id'] = idp_id
+ with sql.transaction() as session:
+ self._get_idp(session, idp_id)
+ protocol_ref = FederationProtocolModel.from_dict(protocol)
+ session.add(protocol_ref)
+ return protocol_ref.to_dict()
+
+ def update_protocol(self, idp_id, protocol_id, protocol):
+ with sql.transaction() as session:
+ proto_ref = self._get_protocol(session, idp_id, protocol_id)
+ old_proto = proto_ref.to_dict()
+ old_proto.update(protocol)
+ new_proto = FederationProtocolModel.from_dict(old_proto)
+ for attr in FederationProtocolModel.mutable_attributes:
+ setattr(proto_ref, attr, getattr(new_proto, attr))
+ return proto_ref.to_dict()
+
+ def get_protocol(self, idp_id, protocol_id):
+ with sql.transaction() as session:
+ protocol_ref = self._get_protocol(session, idp_id, protocol_id)
+ return protocol_ref.to_dict()
+
+ def list_protocols(self, idp_id):
+ with sql.transaction() as session:
+ q = session.query(FederationProtocolModel)
+ q = q.filter_by(idp_id=idp_id)
+ protocols = [protocol.to_dict() for protocol in q]
+ return protocols
+
+ def delete_protocol(self, idp_id, protocol_id):
+ with sql.transaction() as session:
+ key_ref = self._get_protocol(session, idp_id, protocol_id)
+ session.delete(key_ref)
+
+ # Mapping CRUD
+ def _get_mapping(self, session, mapping_id):
+ mapping_ref = session.query(MappingModel).get(mapping_id)
+ if not mapping_ref:
+ raise exception.MappingNotFound(mapping_id=mapping_id)
+ return mapping_ref
+
+ @sql.handle_conflicts(conflict_type='mapping')
+ def create_mapping(self, mapping_id, mapping):
+ ref = {}
+ ref['id'] = mapping_id
+ ref['rules'] = jsonutils.dumps(mapping.get('rules'))
+ with sql.transaction() as session:
+ mapping_ref = MappingModel.from_dict(ref)
+ session.add(mapping_ref)
+ return mapping_ref.to_dict()
+
+ def delete_mapping(self, mapping_id):
+ with sql.transaction() as session:
+ mapping_ref = self._get_mapping(session, mapping_id)
+ session.delete(mapping_ref)
+
+ def list_mappings(self):
+ with sql.transaction() as session:
+ mappings = session.query(MappingModel)
+ return [x.to_dict() for x in mappings]
+
+ def get_mapping(self, mapping_id):
+ with sql.transaction() as session:
+ mapping_ref = self._get_mapping(session, mapping_id)
+ return mapping_ref.to_dict()
+
+ @sql.handle_conflicts(conflict_type='mapping')
+ def update_mapping(self, mapping_id, mapping):
+ ref = {}
+ ref['id'] = mapping_id
+ ref['rules'] = jsonutils.dumps(mapping.get('rules'))
+ with sql.transaction() as session:
+ mapping_ref = self._get_mapping(session, mapping_id)
+ old_mapping = mapping_ref.to_dict()
+ old_mapping.update(ref)
+ new_mapping = MappingModel.from_dict(old_mapping)
+ for attr in MappingModel.attributes:
+ setattr(mapping_ref, attr, getattr(new_mapping, attr))
+ return mapping_ref.to_dict()
+
+ def get_mapping_from_idp_and_protocol(self, idp_id, protocol_id):
+ with sql.transaction() as session:
+ protocol_ref = self._get_protocol(session, idp_id, protocol_id)
+ mapping_id = protocol_ref.mapping_id
+ mapping_ref = self._get_mapping(session, mapping_id)
+ return mapping_ref.to_dict()
+
+ # Service Provider CRUD
+ @sql.handle_conflicts(conflict_type='service_provider')
+ def create_sp(self, sp_id, sp):
+ sp['id'] = sp_id
+ with sql.transaction() as session:
+ sp_ref = ServiceProviderModel.from_dict(sp)
+ session.add(sp_ref)
+ return sp_ref.to_dict()
+
+ def delete_sp(self, sp_id):
+ with sql.transaction() as session:
+ sp_ref = self._get_sp(session, sp_id)
+ session.delete(sp_ref)
+
+ def _get_sp(self, session, sp_id):
+ sp_ref = session.query(ServiceProviderModel).get(sp_id)
+ if not sp_ref:
+ raise exception.ServiceProviderNotFound(sp_id=sp_id)
+ return sp_ref
+
+ def list_sps(self):
+ with sql.transaction() as session:
+ sps = session.query(ServiceProviderModel)
+ sps_list = [sp.to_dict() for sp in sps]
+ return sps_list
+
+ def get_sp(self, sp_id):
+ with sql.transaction() as session:
+ sp_ref = self._get_sp(session, sp_id)
+ return sp_ref.to_dict()
+
+ def update_sp(self, sp_id, sp):
+ with sql.transaction() as session:
+ sp_ref = self._get_sp(session, sp_id)
+ old_sp = sp_ref.to_dict()
+ old_sp.update(sp)
+ new_sp = ServiceProviderModel.from_dict(old_sp)
+ for attr in ServiceProviderModel.mutable_attributes:
+ setattr(sp_ref, attr, getattr(new_sp, attr))
+ return sp_ref.to_dict()
+
+ def get_enabled_service_providers(self):
+ with sql.transaction() as session:
+ service_providers = session.query(ServiceProviderModel)
+ service_providers = service_providers.filter_by(enabled=True)
+ return service_providers