aboutsummaryrefslogtreecommitdiffstats
path: root/keystone-moon/keystone/contrib/federation/backends/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'keystone-moon/keystone/contrib/federation/backends/sql.py')
-rw-r--r--keystone-moon/keystone/contrib/federation/backends/sql.py71
1 files changed, 61 insertions, 10 deletions
diff --git a/keystone-moon/keystone/contrib/federation/backends/sql.py b/keystone-moon/keystone/contrib/federation/backends/sql.py
index f2c124d0..ed07c08f 100644
--- a/keystone-moon/keystone/contrib/federation/backends/sql.py
+++ b/keystone-moon/keystone/contrib/federation/backends/sql.py
@@ -17,6 +17,7 @@ from oslo_serialization import jsonutils
from keystone.common import sql
from keystone.contrib.federation import core
from keystone import exception
+from sqlalchemy import orm
class FederationProtocolModel(sql.ModelBase, sql.DictBase):
@@ -44,13 +45,53 @@ class FederationProtocolModel(sql.ModelBase, sql.DictBase):
class IdentityProviderModel(sql.ModelBase, sql.DictBase):
__tablename__ = 'identity_provider'
- attributes = ['id', 'remote_id', 'enabled', 'description']
- mutable_attributes = frozenset(['description', 'enabled', 'remote_id'])
+ attributes = ['id', 'enabled', 'description', 'remote_ids']
+ mutable_attributes = frozenset(['description', 'enabled', 'remote_ids'])
id = sql.Column(sql.String(64), primary_key=True)
- remote_id = sql.Column(sql.String(256), nullable=True)
enabled = sql.Column(sql.Boolean, nullable=False)
description = sql.Column(sql.Text(), nullable=True)
+ remote_ids = orm.relationship('IdPRemoteIdsModel',
+ order_by='IdPRemoteIdsModel.remote_id',
+ cascade='all, delete-orphan')
+
+ @classmethod
+ def from_dict(cls, dictionary):
+ new_dictionary = dictionary.copy()
+ remote_ids_list = new_dictionary.pop('remote_ids', None)
+ if not remote_ids_list:
+ remote_ids_list = []
+ identity_provider = cls(**new_dictionary)
+ remote_ids = []
+ # NOTE(fmarco76): the remote_ids_list contains only remote ids
+ # associated with the IdP because of the "relationship" established in
+ # sqlalchemy and corresponding to the FK in the idp_remote_ids table
+ for remote in remote_ids_list:
+ remote_ids.append(IdPRemoteIdsModel(remote_id=remote))
+ identity_provider.remote_ids = remote_ids
+ return identity_provider
+
+ def to_dict(self):
+ """Return a dictionary with model's attributes."""
+ d = dict()
+ for attr in self.__class__.attributes:
+ d[attr] = getattr(self, attr)
+ d['remote_ids'] = []
+ for remote in self.remote_ids:
+ d['remote_ids'].append(remote.remote_id)
+ return d
+
+
+class IdPRemoteIdsModel(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'idp_remote_ids'
+ attributes = ['idp_id', 'remote_id']
+ mutable_attributes = frozenset(['idp_id', 'remote_id'])
+
+ idp_id = sql.Column(sql.String(64),
+ sql.ForeignKey('identity_provider.id',
+ ondelete='CASCADE'))
+ remote_id = sql.Column(sql.String(255),
+ primary_key=True)
@classmethod
def from_dict(cls, dictionary):
@@ -75,6 +116,7 @@ class MappingModel(sql.ModelBase, sql.DictBase):
@classmethod
def from_dict(cls, dictionary):
new_dictionary = dictionary.copy()
+ new_dictionary['rules'] = jsonutils.dumps(new_dictionary['rules'])
return cls(**new_dictionary)
def to_dict(self):
@@ -82,20 +124,23 @@ class MappingModel(sql.ModelBase, sql.DictBase):
d = dict()
for attr in self.__class__.attributes:
d[attr] = getattr(self, attr)
+ d['rules'] = jsonutils.loads(d['rules'])
return d
class ServiceProviderModel(sql.ModelBase, sql.DictBase):
__tablename__ = 'service_provider'
- attributes = ['auth_url', 'id', 'enabled', 'description', 'sp_url']
+ attributes = ['auth_url', 'id', 'enabled', 'description',
+ 'relay_state_prefix', 'sp_url']
mutable_attributes = frozenset(['auth_url', 'description', 'enabled',
- 'sp_url'])
+ 'relay_state_prefix', 'sp_url'])
id = sql.Column(sql.String(64), primary_key=True)
enabled = sql.Column(sql.Boolean, nullable=False)
description = sql.Column(sql.Text(), nullable=True)
auth_url = sql.Column(sql.String(256), nullable=False)
sp_url = sql.Column(sql.String(256), nullable=False)
+ relay_state_prefix = sql.Column(sql.String(256), nullable=False)
@classmethod
def from_dict(cls, dictionary):
@@ -123,6 +168,7 @@ class Federation(core.Driver):
def delete_idp(self, idp_id):
with sql.transaction() as session:
+ self._delete_assigned_protocols(session, idp_id)
idp_ref = self._get_idp(session, idp_id)
session.delete(idp_ref)
@@ -133,7 +179,7 @@ class Federation(core.Driver):
return idp_ref
def _get_idp_from_remote_id(self, session, remote_id):
- q = session.query(IdentityProviderModel)
+ q = session.query(IdPRemoteIdsModel)
q = q.filter_by(remote_id=remote_id)
try:
return q.one()
@@ -153,8 +199,8 @@ class Federation(core.Driver):
def get_idp_from_remote_id(self, remote_id):
with sql.transaction() as session:
- idp_ref = self._get_idp_from_remote_id(session, remote_id)
- return idp_ref.to_dict()
+ ref = self._get_idp_from_remote_id(session, remote_id)
+ return ref.to_dict()
def update_idp(self, idp_id, idp):
with sql.transaction() as session:
@@ -214,6 +260,11 @@ class Federation(core.Driver):
key_ref = self._get_protocol(session, idp_id, protocol_id)
session.delete(key_ref)
+ def _delete_assigned_protocols(self, session, idp_id):
+ query = session.query(FederationProtocolModel)
+ query = query.filter_by(idp_id=idp_id)
+ query.delete()
+
# Mapping CRUD
def _get_mapping(self, session, mapping_id):
mapping_ref = session.query(MappingModel).get(mapping_id)
@@ -225,7 +276,7 @@ class Federation(core.Driver):
def create_mapping(self, mapping_id, mapping):
ref = {}
ref['id'] = mapping_id
- ref['rules'] = jsonutils.dumps(mapping.get('rules'))
+ ref['rules'] = mapping.get('rules')
with sql.transaction() as session:
mapping_ref = MappingModel.from_dict(ref)
session.add(mapping_ref)
@@ -250,7 +301,7 @@ class Federation(core.Driver):
def update_mapping(self, mapping_id, mapping):
ref = {}
ref['id'] = mapping_id
- ref['rules'] = jsonutils.dumps(mapping.get('rules'))
+ ref['rules'] = mapping.get('rules')
with sql.transaction() as session:
mapping_ref = self._get_mapping(session, mapping_id)
old_mapping = mapping_ref.to_dict()