From 2e7b4f2027a1147ca28301e4f88adf8274b39a1f Mon Sep 17 00:00:00 2001 From: DUVAL Thomas Date: Thu, 9 Jun 2016 09:11:50 +0200 Subject: Update Keystone core to Mitaka. Change-Id: Ia10d6add16f4a9d25d1f42d420661c46332e69db --- keystone-moon/keystone/identity/__init__.py | 1 - keystone-moon/keystone/identity/backends/ldap.py | 62 ++-- keystone-moon/keystone/identity/backends/sql.py | 345 +++++++++++++-------- keystone-moon/keystone/identity/controllers.py | 10 +- keystone-moon/keystone/identity/core.py | 303 +++++++++++++----- .../keystone/identity/mapping_backends/sql.py | 63 ++-- .../keystone/identity/shadow_backends/__init__.py | 0 .../keystone/identity/shadow_backends/sql.py | 73 +++++ 8 files changed, 611 insertions(+), 246 deletions(-) create mode 100644 keystone-moon/keystone/identity/shadow_backends/__init__.py create mode 100644 keystone-moon/keystone/identity/shadow_backends/sql.py (limited to 'keystone-moon/keystone/identity') diff --git a/keystone-moon/keystone/identity/__init__.py b/keystone-moon/keystone/identity/__init__.py index 3063b5ca..96b3ee77 100644 --- a/keystone-moon/keystone/identity/__init__.py +++ b/keystone-moon/keystone/identity/__init__.py @@ -15,4 +15,3 @@ from keystone.identity import controllers # noqa from keystone.identity.core import * # noqa from keystone.identity import generator # noqa -from keystone.identity import routers # noqa diff --git a/keystone-moon/keystone/identity/backends/ldap.py b/keystone-moon/keystone/identity/backends/ldap.py index 1f33bacb..fe8e8477 100644 --- a/keystone-moon/keystone/identity/backends/ldap.py +++ b/keystone-moon/keystone/identity/backends/ldap.py @@ -17,6 +17,7 @@ import uuid import ldap.filter from oslo_config import cfg from oslo_log import log +from oslo_log import versionutils import six from keystone.common import clean @@ -31,17 +32,20 @@ from keystone import identity CONF = cfg.CONF LOG = log.getLogger(__name__) +_DEPRECATION_MSG = _('%s for the LDAP identity backend has been deprecated in ' + 'the Mitaka release in favor of read-only identity LDAP ' + 'access. It will be removed in the "O" release.') + class Identity(identity.IdentityDriverV8): def __init__(self, conf=None): super(Identity, self).__init__() if conf is None: - conf = CONF - self.user = UserApi(conf) - self.group = GroupApi(conf) - - def default_assignment_driver(self): - return 'ldap' + self.conf = CONF + else: + self.conf = conf + self.user = UserApi(self.conf) + self.group = GroupApi(self.conf) def is_domain_aware(self): return False @@ -87,11 +91,15 @@ class Identity(identity.IdentityDriverV8): # CRUD def create_user(self, user_id, user): + msg = _DEPRECATION_MSG % "create_user" + versionutils.report_deprecated_feature(LOG, msg) self.user.check_allow_create() user_ref = self.user.create(user) return self.user.filter_attributes(user_ref) def update_user(self, user_id, user): + msg = _DEPRECATION_MSG % "update_user" + versionutils.report_deprecated_feature(LOG, msg) self.user.check_allow_update() old_obj = self.user.get(user_id) if 'name' in user and old_obj.get('name') != user['name']: @@ -110,6 +118,8 @@ class Identity(identity.IdentityDriverV8): return self.user.get_filtered(user_id) def delete_user(self, user_id): + msg = _DEPRECATION_MSG % "delete_user" + versionutils.report_deprecated_feature(LOG, msg) self.user.check_allow_delete() user = self.user.get(user_id) user_dn = user['dn'] @@ -122,6 +132,8 @@ class Identity(identity.IdentityDriverV8): self.user.delete(user_id) def create_group(self, group_id, group): + msg = _DEPRECATION_MSG % "create_group" + versionutils.report_deprecated_feature(LOG, msg) self.group.check_allow_create() group['name'] = clean.group_name(group['name']) return common_ldap.filter_entity(self.group.create(group)) @@ -135,28 +147,39 @@ class Identity(identity.IdentityDriverV8): return self.group.get_filtered_by_name(group_name) def update_group(self, group_id, group): + msg = _DEPRECATION_MSG % "update_group" + versionutils.report_deprecated_feature(LOG, msg) self.group.check_allow_update() if 'name' in group: group['name'] = clean.group_name(group['name']) return common_ldap.filter_entity(self.group.update(group_id, group)) def delete_group(self, group_id): + msg = _DEPRECATION_MSG % "delete_group" + versionutils.report_deprecated_feature(LOG, msg) self.group.check_allow_delete() return self.group.delete(group_id) def add_user_to_group(self, user_id, group_id): + msg = _DEPRECATION_MSG % "add_user_to_group" + versionutils.report_deprecated_feature(LOG, msg) user_ref = self._get_user(user_id) user_dn = user_ref['dn'] self.group.add_user(user_dn, group_id, user_id) def remove_user_from_group(self, user_id, group_id): + msg = _DEPRECATION_MSG % "remove_user_from_group" + versionutils.report_deprecated_feature(LOG, msg) user_ref = self._get_user(user_id) user_dn = user_ref['dn'] self.group.remove_user(user_dn, group_id, user_id) def list_groups_for_user(self, user_id, hints): user_ref = self._get_user(user_id) - user_dn = user_ref['dn'] + if self.conf.ldap.group_members_are_ids: + user_dn = user_ref['id'] + else: + user_dn = user_ref['dn'] return self.group.list_user_groups_filtered(user_dn, hints) def list_groups(self, hints): @@ -164,15 +187,19 @@ class Identity(identity.IdentityDriverV8): def list_users_in_group(self, group_id, hints): users = [] - for user_dn in self.group.list_group_users(group_id): - user_id = self.user._dn_to_id(user_dn) + for user_key in self.group.list_group_users(group_id): + if self.conf.ldap.group_members_are_ids: + user_id = user_key + else: + user_id = self.user._dn_to_id(user_key) + try: users.append(self.user.get_filtered(user_id)) except exception.UserNotFound: - LOG.debug(("Group member '%(user_dn)s' not found in" + LOG.debug(("Group member '%(user_key)s' not found in" " '%(group_id)s'. The user should be removed" " from the group. The user will be ignored."), - dict(user_dn=user_dn, group_id=group_id)) + dict(user_key=user_key, group_id=group_id)) return users def check_user_in_group(self, user_id, group_id): @@ -201,6 +228,7 @@ class UserApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap): attribute_options_names = {'password': 'pass', 'email': 'mail', 'name': 'name', + 'description': 'description', 'enabled': 'enabled', 'default_project_id': 'default_project_id'} immutable_attrs = ['id'] @@ -264,15 +292,15 @@ class UserApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap): return self.filter_attributes(user) def get_all_filtered(self, hints): - query = self.filter_query(hints) - return [self.filter_attributes(user) for user in self.get_all(query)] + query = self.filter_query(hints, self.ldap_filter) + return [self.filter_attributes(user) + for user in self.get_all(query, hints)] def filter_attributes(self, user): return identity.filter_user(common_ldap.filter_entity(user)) def is_user(self, dn): """Returns True if the entry is a user.""" - # NOTE(blk-u): It's easy to check if the DN is under the User tree, # but may not be accurate. A more accurate test would be to fetch the # entry to see if it's got the user objectclass, but this could be @@ -314,7 +342,7 @@ class GroupApi(common_ldap.BaseLdap): def delete(self, group_id): if self.subtree_delete_enabled: - super(GroupApi, self).deleteTree(group_id) + super(GroupApi, self).delete_tree(group_id) else: # TODO(spzala): this is only placeholder for group and domain # role support which will be added under bug 1101287 @@ -349,7 +377,6 @@ class GroupApi(common_ldap.BaseLdap): def list_user_groups(self, user_dn): """Return a list of groups for which the user is a member.""" - user_dn_esc = ldap.filter.escape_filter_chars(user_dn) query = '(%s=%s)%s' % (self.member_attribute, user_dn_esc, @@ -358,7 +385,6 @@ class GroupApi(common_ldap.BaseLdap): def list_user_groups_filtered(self, user_dn, hints): """Return a filtered list of groups for which the user is a member.""" - user_dn_esc = ldap.filter.escape_filter_chars(user_dn) query = '(%s=%s)%s' % (self.member_attribute, user_dn_esc, @@ -396,4 +422,4 @@ class GroupApi(common_ldap.BaseLdap): def get_all_filtered(self, hints, query=None): query = self.filter_query(hints, query) return [common_ldap.filter_entity(group) - for group in self.get_all(query)] + for group in self.get_all(query, hints)] diff --git a/keystone-moon/keystone/identity/backends/sql.py b/keystone-moon/keystone/identity/backends/sql.py index d37240eb..5680a8a2 100644 --- a/keystone-moon/keystone/identity/backends/sql.py +++ b/keystone-moon/keystone/identity/backends/sql.py @@ -12,8 +12,11 @@ # License for the specific language governing permissions and limitations # under the License. -from oslo_config import cfg +import sqlalchemy +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy import orm +from keystone.common import driver_hints from keystone.common import sql from keystone.common import utils from keystone import exception @@ -21,23 +24,84 @@ from keystone.i18n import _ from keystone import identity -CONF = cfg.CONF - - class User(sql.ModelBase, sql.DictBase): __tablename__ = 'user' attributes = ['id', 'name', 'domain_id', 'password', 'enabled', 'default_project_id'] id = sql.Column(sql.String(64), primary_key=True) - name = sql.Column(sql.String(255), nullable=False) - domain_id = sql.Column(sql.String(64), nullable=False) - password = sql.Column(sql.String(128)) enabled = sql.Column(sql.Boolean) extra = sql.Column(sql.JsonBlob()) default_project_id = sql.Column(sql.String(64)) - # Unique constraint across two columns to create the separation - # rather than just only 'name' being unique - __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {}) + local_user = orm.relationship('LocalUser', uselist=False, + single_parent=True, lazy='subquery', + cascade='all,delete-orphan', backref='user') + federated_users = orm.relationship('FederatedUser', + single_parent=True, + lazy='subquery', + cascade='all,delete-orphan', + backref='user') + + # name property + @hybrid_property + def name(self): + if self.local_user: + return self.local_user.name + elif self.federated_users: + return self.federated_users[0].display_name + else: + return None + + @name.setter + def name(self, value): + if not self.local_user: + self.local_user = LocalUser() + self.local_user.name = value + + @name.expression + def name(cls): + return LocalUser.name + + # password property + @hybrid_property + def password(self): + if self.local_user and self.local_user.passwords: + return self.local_user.passwords[0].password + else: + return None + + @password.setter + def password(self, value): + if not value: + if self.local_user and self.local_user.passwords: + self.local_user.passwords = [] + else: + if not self.local_user: + self.local_user = LocalUser() + if not self.local_user.passwords: + self.local_user.passwords.append(Password()) + self.local_user.passwords[0].password = value + + @password.expression + def password(cls): + return Password.password + + # domain_id property + @hybrid_property + def domain_id(self): + if self.local_user: + return self.local_user.domain_id + else: + return None + + @domain_id.setter + def domain_id(self, value): + if not self.local_user: + self.local_user = LocalUser() + self.local_user.domain_id = value + + @domain_id.expression + def domain_id(cls): + return LocalUser.domain_id def to_dict(self, include_extra_dict=False): d = super(User, self).to_dict(include_extra_dict=include_extra_dict) @@ -46,6 +110,49 @@ class User(sql.ModelBase, sql.DictBase): return d +class LocalUser(sql.ModelBase, sql.DictBase): + __tablename__ = 'local_user' + attributes = ['id', 'user_id', 'domain_id', 'name'] + id = sql.Column(sql.Integer, primary_key=True) + user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id', + ondelete='CASCADE'), unique=True) + domain_id = sql.Column(sql.String(64), nullable=False) + name = sql.Column(sql.String(255), nullable=False) + passwords = orm.relationship('Password', single_parent=True, + cascade='all,delete-orphan', + backref='local_user') + __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {}) + + +class Password(sql.ModelBase, sql.DictBase): + __tablename__ = 'password' + attributes = ['id', 'local_user_id', 'password'] + id = sql.Column(sql.Integer, primary_key=True) + local_user_id = sql.Column(sql.Integer, sql.ForeignKey('local_user.id', + ondelete='CASCADE')) + password = sql.Column(sql.String(128)) + + +class FederatedUser(sql.ModelBase, sql.ModelDictMixin): + __tablename__ = 'federated_user' + attributes = ['id', 'user_id', 'idp_id', 'protocol_id', 'unique_id', + 'display_name'] + id = sql.Column(sql.Integer, primary_key=True) + user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id', + ondelete='CASCADE')) + idp_id = sql.Column(sql.String(64), sql.ForeignKey('identity_provider.id', + ondelete='CASCADE')) + protocol_id = sql.Column(sql.String(64), nullable=False) + unique_id = sql.Column(sql.String(255), nullable=False) + display_name = sql.Column(sql.String(255), nullable=True) + __table_args__ = ( + sql.UniqueConstraint('idp_id', 'protocol_id', 'unique_id'), + sqlalchemy.ForeignKeyConstraint(['protocol_id', 'idp_id'], + ['federation_protocol.id', + 'federation_protocol.idp_id']) + ) + + class Group(sql.ModelBase, sql.DictBase): __tablename__ = 'group' attributes = ['id', 'name', 'domain_id', 'description'] @@ -56,11 +163,12 @@ class Group(sql.ModelBase, sql.DictBase): extra = sql.Column(sql.JsonBlob()) # Unique constraint across two columns to create the separation # rather than just only 'name' being unique - __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {}) + __table_args__ = (sql.UniqueConstraint('domain_id', 'name'),) class UserGroupMembership(sql.ModelBase, sql.DictBase): """Group membership join table.""" + __tablename__ = 'user_group_membership' user_id = sql.Column(sql.String(64), sql.ForeignKey('user.id'), @@ -74,11 +182,9 @@ class Identity(identity.IdentityDriverV8): # NOTE(henry-nash): Override the __init__() method so as to take a # config parameter to enable sql to be used as a domain-specific driver. def __init__(self, conf=None): + self.conf = conf super(Identity, self).__init__() - def default_assignment_driver(self): - return 'sql' - @property def is_sql(self): return True @@ -96,33 +202,32 @@ class Identity(identity.IdentityDriverV8): # Identity interface def authenticate(self, user_id, password): - session = sql.get_session() - user_ref = None - try: - user_ref = self._get_user(session, user_id) - except exception.UserNotFound: - raise AssertionError(_('Invalid user / password')) - if not self._check_password(password, user_ref): - raise AssertionError(_('Invalid user / password')) - return identity.filter_user(user_ref.to_dict()) + with sql.session_for_read() as session: + user_ref = None + try: + user_ref = self._get_user(session, user_id) + except exception.UserNotFound: + raise AssertionError(_('Invalid user / password')) + if not self._check_password(password, user_ref): + raise AssertionError(_('Invalid user / password')) + return identity.filter_user(user_ref.to_dict()) # user crud @sql.handle_conflicts(conflict_type='user') def create_user(self, user_id, user): user = utils.hash_user_password(user) - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: user_ref = User.from_dict(user) session.add(user_ref) - return identity.filter_user(user_ref.to_dict()) + return identity.filter_user(user_ref.to_dict()) - @sql.truncated + @driver_hints.truncated def list_users(self, hints): - session = sql.get_session() - query = session.query(User) - user_refs = sql.filter_limit_query(User, query, hints) - return [identity.filter_user(x.to_dict()) for x in user_refs] + with sql.session_for_read() as session: + query = session.query(User).outerjoin(LocalUser) + user_refs = sql.filter_limit_query(User, query, hints) + return [identity.filter_user(x.to_dict()) for x in user_refs] def _get_user(self, session, user_id): user_ref = session.query(User).get(user_id) @@ -131,25 +236,24 @@ class Identity(identity.IdentityDriverV8): return user_ref def get_user(self, user_id): - session = sql.get_session() - return identity.filter_user(self._get_user(session, user_id).to_dict()) + with sql.session_for_read() as session: + return identity.filter_user( + self._get_user(session, user_id).to_dict()) def get_user_by_name(self, user_name, domain_id): - session = sql.get_session() - query = session.query(User) - query = query.filter_by(name=user_name) - query = query.filter_by(domain_id=domain_id) - try: - user_ref = query.one() - except sql.NotFound: - raise exception.UserNotFound(user_id=user_name) - return identity.filter_user(user_ref.to_dict()) + with sql.session_for_read() as session: + query = session.query(User).join(LocalUser) + query = query.filter(sqlalchemy.and_(LocalUser.name == user_name, + LocalUser.domain_id == domain_id)) + try: + user_ref = query.one() + except sql.NotFound: + raise exception.UserNotFound(user_id=user_name) + return identity.filter_user(user_ref.to_dict()) @sql.handle_conflicts(conflict_type='user') def update_user(self, user_id, user): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: user_ref = self._get_user(session, user_id) old_user_dict = user_ref.to_dict() user = utils.hash_user_password(user) @@ -160,76 +264,74 @@ class Identity(identity.IdentityDriverV8): if attr != 'id': setattr(user_ref, attr, getattr(new_user, attr)) user_ref.extra = new_user.extra - return identity.filter_user(user_ref.to_dict(include_extra_dict=True)) + return identity.filter_user( + user_ref.to_dict(include_extra_dict=True)) def add_user_to_group(self, user_id, group_id): - session = sql.get_session() - self.get_group(group_id) - self.get_user(user_id) - query = session.query(UserGroupMembership) - query = query.filter_by(user_id=user_id) - query = query.filter_by(group_id=group_id) - rv = query.first() - if rv: - return - - with session.begin(): + with sql.session_for_write() as session: + self.get_group(group_id) + self.get_user(user_id) + query = session.query(UserGroupMembership) + query = query.filter_by(user_id=user_id) + query = query.filter_by(group_id=group_id) + rv = query.first() + if rv: + return + session.add(UserGroupMembership(user_id=user_id, group_id=group_id)) def check_user_in_group(self, user_id, group_id): - session = sql.get_session() - self.get_group(group_id) - self.get_user(user_id) - query = session.query(UserGroupMembership) - query = query.filter_by(user_id=user_id) - query = query.filter_by(group_id=group_id) - if not query.first(): - raise exception.NotFound(_("User '%(user_id)s' not found in" - " group '%(group_id)s'") % - {'user_id': user_id, - 'group_id': group_id}) + with sql.session_for_read() as session: + self.get_group(group_id) + self.get_user(user_id) + query = session.query(UserGroupMembership) + query = query.filter_by(user_id=user_id) + query = query.filter_by(group_id=group_id) + if not query.first(): + raise exception.NotFound(_("User '%(user_id)s' not found in" + " group '%(group_id)s'") % + {'user_id': user_id, + 'group_id': group_id}) def remove_user_from_group(self, user_id, group_id): - session = sql.get_session() # We don't check if user or group are still valid and let the remove # be tried anyway - in case this is some kind of clean-up operation - query = session.query(UserGroupMembership) - query = query.filter_by(user_id=user_id) - query = query.filter_by(group_id=group_id) - membership_ref = query.first() - if membership_ref is None: - # Check if the group and user exist to return descriptive - # exceptions. - self.get_group(group_id) - self.get_user(user_id) - raise exception.NotFound(_("User '%(user_id)s' not found in" - " group '%(group_id)s'") % - {'user_id': user_id, - 'group_id': group_id}) - with session.begin(): + with sql.session_for_write() as session: + query = session.query(UserGroupMembership) + query = query.filter_by(user_id=user_id) + query = query.filter_by(group_id=group_id) + membership_ref = query.first() + if membership_ref is None: + # Check if the group and user exist to return descriptive + # exceptions. + self.get_group(group_id) + self.get_user(user_id) + raise exception.NotFound(_("User '%(user_id)s' not found in" + " group '%(group_id)s'") % + {'user_id': user_id, + 'group_id': group_id}) session.delete(membership_ref) def list_groups_for_user(self, user_id, hints): - session = sql.get_session() - self.get_user(user_id) - query = session.query(Group).join(UserGroupMembership) - query = query.filter(UserGroupMembership.user_id == user_id) - query = sql.filter_limit_query(Group, query, hints) - return [g.to_dict() for g in query] + with sql.session_for_read() as session: + self.get_user(user_id) + query = session.query(Group).join(UserGroupMembership) + query = query.filter(UserGroupMembership.user_id == user_id) + query = sql.filter_limit_query(Group, query, hints) + return [g.to_dict() for g in query] def list_users_in_group(self, group_id, hints): - session = sql.get_session() - self.get_group(group_id) - query = session.query(User).join(UserGroupMembership) - query = query.filter(UserGroupMembership.group_id == group_id) - query = sql.filter_limit_query(User, query, hints) - return [identity.filter_user(u.to_dict()) for u in query] + with sql.session_for_read() as session: + self.get_group(group_id) + query = session.query(User).outerjoin(LocalUser) + query = query.join(UserGroupMembership) + query = query.filter(UserGroupMembership.group_id == group_id) + query = sql.filter_limit_query(User, query, hints) + return [identity.filter_user(u.to_dict()) for u in query] def delete_user(self, user_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_user(session, user_id) q = session.query(UserGroupMembership) @@ -242,18 +344,17 @@ class Identity(identity.IdentityDriverV8): @sql.handle_conflicts(conflict_type='group') def create_group(self, group_id, group): - session = sql.get_session() - with session.begin(): + with sql.session_for_write() as session: ref = Group.from_dict(group) session.add(ref) - return ref.to_dict() + return ref.to_dict() - @sql.truncated + @driver_hints.truncated def list_groups(self, hints): - session = sql.get_session() - query = session.query(Group) - refs = sql.filter_limit_query(Group, query, hints) - return [ref.to_dict() for ref in refs] + with sql.session_for_read() as session: + query = session.query(Group) + refs = sql.filter_limit_query(Group, query, hints) + return [ref.to_dict() for ref in refs] def _get_group(self, session, group_id): ref = session.query(Group).get(group_id) @@ -262,25 +363,23 @@ class Identity(identity.IdentityDriverV8): return ref def get_group(self, group_id): - session = sql.get_session() - return self._get_group(session, group_id).to_dict() + with sql.session_for_read() as session: + return self._get_group(session, group_id).to_dict() def get_group_by_name(self, group_name, domain_id): - session = sql.get_session() - query = session.query(Group) - query = query.filter_by(name=group_name) - query = query.filter_by(domain_id=domain_id) - try: - group_ref = query.one() - except sql.NotFound: - raise exception.GroupNotFound(group_id=group_name) - return group_ref.to_dict() + with sql.session_for_read() as session: + query = session.query(Group) + query = query.filter_by(name=group_name) + query = query.filter_by(domain_id=domain_id) + try: + group_ref = query.one() + except sql.NotFound: + raise exception.GroupNotFound(group_id=group_name) + return group_ref.to_dict() @sql.handle_conflicts(conflict_type='group') def update_group(self, group_id, group): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_group(session, group_id) old_dict = ref.to_dict() for k in group: @@ -290,12 +389,10 @@ class Identity(identity.IdentityDriverV8): if attr != 'id': setattr(ref, attr, getattr(new_group, attr)) ref.extra = new_group.extra - return ref.to_dict() + return ref.to_dict() def delete_group(self, group_id): - session = sql.get_session() - - with session.begin(): + with sql.session_for_write() as session: ref = self._get_group(session, group_id) q = session.query(UserGroupMembership) diff --git a/keystone-moon/keystone/identity/controllers.py b/keystone-moon/keystone/identity/controllers.py index 0ec38190..9e8ba6fc 100644 --- a/keystone-moon/keystone/identity/controllers.py +++ b/keystone-moon/keystone/identity/controllers.py @@ -80,6 +80,8 @@ class User(controller.V2Controller): self.resource_api.get_project(default_project_id) user['default_project_id'] = default_project_id + self.resource_api.ensure_default_domain_exists() + # The manager layer will generate the unique ID for users user_ref = self._normalize_domain_id(context, user.copy()) initiator = notifications._get_request_audit_info(context) @@ -149,7 +151,7 @@ class User(controller.V2Controller): try: self.assignment_api.add_user_to_project( user_ref['tenantId'], user_id) - except exception.Conflict: + except exception.Conflict: # nosec # We are already a member of that tenant pass except exception.NotFound: @@ -253,7 +255,8 @@ class UserV3(controller.V3Controller): @controller.protected(callback=_check_user_and_group_protection) def add_user_to_group(self, context, user_id, group_id): - self.identity_api.add_user_to_group(user_id, group_id) + initiator = notifications._get_request_audit_info(context) + self.identity_api.add_user_to_group(user_id, group_id, initiator) @controller.protected(callback=_check_user_and_group_protection) def check_user_in_group(self, context, user_id, group_id): @@ -261,7 +264,8 @@ class UserV3(controller.V3Controller): @controller.protected(callback=_check_user_and_group_protection) def remove_user_from_group(self, context, user_id, group_id): - self.identity_api.remove_user_from_group(user_id, group_id) + initiator = notifications._get_request_audit_info(context) + self.identity_api.remove_user_from_group(user_id, group_id, initiator) @controller.protected() def delete_user(self, context, user_id): diff --git a/keystone-moon/keystone/identity/core.py b/keystone-moon/keystone/identity/core.py index 061b82e1..2f52a358 100644 --- a/keystone-moon/keystone/identity/core.py +++ b/keystone-moon/keystone/identity/core.py @@ -17,18 +17,21 @@ import abc import functools import os +import threading import uuid from oslo_config import cfg from oslo_log import log +from oslo_log import versionutils import six +from keystone import assignment # TODO(lbragstad): Decouple this dependency from keystone.common import cache from keystone.common import clean +from keystone.common import config from keystone.common import dependency from keystone.common import driver_hints from keystone.common import manager -from keystone import config from keystone import exception from keystone.i18n import _, _LW from keystone.identity.mapping_backends import mapping @@ -39,7 +42,7 @@ CONF = cfg.CONF LOG = log.getLogger(__name__) -MEMOIZE = cache.get_memoization_decorator(section='identity') +MEMOIZE = cache.get_memoization_decorator(group='identity') DOMAIN_CONF_FHEAD = 'keystone.' DOMAIN_CONF_FTAIL = '.conf' @@ -70,7 +73,8 @@ def filter_user(user_ref): try: user_ref['extra'].pop('password', None) user_ref['extra'].pop('tenants', None) - except KeyError: + except KeyError: # nosec + # ok to not have extra in the user_ref. pass return user_ref @@ -92,43 +96,33 @@ class DomainConfigs(dict): the identity manager and driver can use. """ + configured = False driver = None _any_sql = False + lock = threading.Lock() def _load_driver(self, domain_config): return manager.load_driver(Manager.driver_namespace, domain_config['cfg'].identity.driver, domain_config['cfg']) - def _assert_no_more_than_one_sql_driver(self, domain_id, new_config, - config_file=None): - """Ensure there is no more than one sql driver. - - Check to see if the addition of the driver in this new config - would cause there to now be more than one sql driver. + def _load_config_from_file(self, resource_api, file_list, domain_name): - If we are loading from configuration files, the config_file will hold - the name of the file we have just loaded. + def _assert_no_more_than_one_sql_driver(domain_id, new_config, + config_file): + """Ensure there is no more than one sql driver. - """ - if (new_config['driver'].is_sql and - (self.driver.is_sql or self._any_sql)): - # The addition of this driver would cause us to have more than - # one sql driver, so raise an exception. - - # TODO(henry-nash): This method is only used in the file-based - # case, so has no need to worry about the database/API case. The - # code that overrides config_file below is therefore never used - # and should be removed, and this method perhaps moved inside - # _load_config_from_file(). This is raised as bug #1466772. - - if not config_file: - config_file = _('Database at /domains/%s/config') % domain_id - raise exception.MultipleSQLDriversInConfig(source=config_file) - self._any_sql = self._any_sql or new_config['driver'].is_sql + Check to see if the addition of the driver in this new config + would cause there to be more than one sql driver. - def _load_config_from_file(self, resource_api, file_list, domain_name): + """ + if (new_config['driver'].is_sql and + (self.driver.is_sql or self._any_sql)): + # The addition of this driver would cause us to have more than + # one sql driver, so raise an exception. + raise exception.MultipleSQLDriversInConfig(source=config_file) + self._any_sql = self._any_sql or new_config['driver'].is_sql try: domain_ref = resource_api.get_domain_by_name(domain_name) @@ -149,9 +143,9 @@ class DomainConfigs(dict): domain_config['cfg'](args=[], project='keystone', default_config_files=file_list) domain_config['driver'] = self._load_driver(domain_config) - self._assert_no_more_than_one_sql_driver(domain_ref['id'], - domain_config, - config_file=file_list) + _assert_no_more_than_one_sql_driver(domain_ref['id'], + domain_config, + file_list) self[domain_ref['id']] = domain_config def _setup_domain_drivers_from_files(self, standard_driver, resource_api): @@ -275,7 +269,7 @@ class DomainConfigs(dict): # being able to find who has it...either we were very very very # unlucky or something is awry. msg = _('Exceeded attempts to register domain %(domain)s to use ' - 'the SQL driver, the last domain that appears to have ' + 'the SQL driver, the last domain that appears to have ' 'had it is %(last_domain)s, giving up') % { 'domain': domain_id, 'last_domain': domain_registered} raise exception.UnexpectedError(msg) @@ -322,7 +316,6 @@ class DomainConfigs(dict): def setup_domain_drivers(self, standard_driver, resource_api): # This is called by the api call wrapper - self.configured = True self.driver = standard_driver if CONF.identity.domain_configurations_from_database: @@ -331,6 +324,7 @@ class DomainConfigs(dict): else: self._setup_domain_drivers_from_files(standard_driver, resource_api) + self.configured = True def get_domain_driver(self, domain_id): self.check_config_and_reload_domain_driver_if_required(domain_id) @@ -404,7 +398,7 @@ class DomainConfigs(dict): # specific driver for this domain. try: del self[domain_id] - except KeyError: + except KeyError: # nosec # Allow this error in case we are unlucky and in a # multi-threaded situation, two threads happen to be running # in lock step. @@ -428,15 +422,20 @@ def domains_configured(f): def wrapper(self, *args, **kwargs): if (not self.domain_configs.configured and CONF.identity.domain_specific_drivers_enabled): - self.domain_configs.setup_domain_drivers( - self.driver, self.resource_api) + # If domain specific driver has not been configured, acquire the + # lock and proceed with loading the driver. + with self.domain_configs.lock: + # Check again just in case some other thread has already + # completed domain config. + if not self.domain_configs.configured: + self.domain_configs.setup_domain_drivers( + self.driver, self.resource_api) return f(self, *args, **kwargs) return wrapper def exception_translated(exception_type): """Wraps API calls to map to correct exception.""" - def _exception_translated(f): @functools.wraps(f) def wrapper(self, *args, **kwargs): @@ -458,7 +457,7 @@ def exception_translated(exception_type): @notifications.listener @dependency.provider('identity_api') @dependency.requires('assignment_api', 'credential_api', 'id_mapping_api', - 'resource_api', 'revoke_api') + 'resource_api', 'revoke_api', 'shadow_users_api') class Manager(manager.Manager): """Default pivot point for the Identity backend. @@ -710,7 +709,7 @@ class Manager(manager.Manager): Use the mapping table to look up the domain, driver and local entity that is represented by the provided public ID. Handle the situations - were we do not use the mapping (e.g. single driver that understands + where we do not use the mapping (e.g. single driver that understands UUIDs etc.) """ @@ -799,6 +798,41 @@ class Manager(manager.Manager): not hints.get_exact_filter_by_name('domain_id')): hints.add_filter('domain_id', domain_id) + def _set_list_limit_in_hints(self, hints, driver): + """Set list limit in hints from driver + + If a hints list is provided, the wrapper will insert the relevant + limit into the hints so that the underlying driver call can try and + honor it. If the driver does truncate the response, it will update the + 'truncated' attribute in the 'limit' entry in the hints list, which + enables the caller of this function to know if truncation has taken + place. If, however, the driver layer is unable to perform truncation, + the 'limit' entry is simply left in the hints list for the caller to + handle. + + A _get_list_limit() method is required to be present in the object + class hierarchy, which returns the limit for this backend to which + we will truncate. + + If a hints list is not provided in the arguments of the wrapped call + then any limits set in the config file are ignored. This allows + internal use of such wrapped methods where the entire data set is + needed as input for the calculations of some other API (e.g. get role + assignments for a given project). + + This method, specific to identity manager, is used instead of more + general response_truncated, because the limit for identity entities + can be overriden in domain-specific config files. The driver to use + is determined during processing of the passed parameters and + response_truncated is designed to set the limit before any processing. + """ + if hints is None: + return + + list_limit = driver._get_list_limit() + if list_limit: + hints.set_limit(list_limit) + # The actual driver calls - these are pre/post processed here as # part of the Manager layer to make sure we: # @@ -869,11 +903,11 @@ class Manager(manager.Manager): return self._set_domain_id_and_mapping( ref, domain_id, driver, mapping.EntityType.USER) - @manager.response_truncated @domains_configured @exception_translated('user') def list_users(self, domain_scope=None, hints=None): driver = self._select_identity_driver(domain_scope) + self._set_list_limit_in_hints(hints, driver) hints = hints or driver_hints.Hints() if driver.is_domain_aware(): # Force the domain_scope into the hint to ensure that we only get @@ -887,6 +921,14 @@ class Manager(manager.Manager): return self._set_domain_id_and_mapping( ref_list, domain_scope, driver, mapping.EntityType.USER) + def _check_update_of_domain_id(self, new_domain, old_domain): + if new_domain != old_domain: + versionutils.report_deprecated_feature( + LOG, + _('update of domain_id is deprecated as of Mitaka ' + 'and will be removed in O.') + ) + @domains_configured @exception_translated('user') def update_user(self, user_id, user_ref, initiator=None): @@ -897,6 +939,8 @@ class Manager(manager.Manager): if 'enabled' in user: user['enabled'] = clean.user_enabled(user['enabled']) if 'domain_id' in user: + self._check_update_of_domain_id(user['domain_id'], + old_user_ref['domain_id']) self.resource_api.get_domain(user['domain_id']) if 'id' in user: if user_id != user['id']: @@ -941,6 +985,10 @@ class Manager(manager.Manager): self.id_mapping_api.delete_id_mapping(user_id) notifications.Audit.deleted(self._USER, user_id, initiator) + # Invalidate user role assignments cache region, as it may be caching + # role assignments where the actor is the specified user + assignment.COMPUTED_ASSIGNMENTS_REGION.invalidate() + @domains_configured @exception_translated('group') def create_group(self, group_ref, initiator=None): @@ -986,6 +1034,9 @@ class Manager(manager.Manager): @exception_translated('group') def update_group(self, group_id, group, initiator=None): if 'domain_id' in group: + old_group_ref = self.get_group(group_id) + self._check_update_of_domain_id(group['domain_id'], + old_group_ref['domain_id']) self.resource_api.get_domain(group['domain_id']) domain_id, driver, entity_id = ( self._get_domain_driver_and_entity_id(group_id)) @@ -1012,9 +1063,13 @@ class Manager(manager.Manager): for uid in user_ids: self.emit_invalidate_user_token_persistence(uid) + # Invalidate user role assignments cache region, as it may be caching + # role assignments expanded from the specified group to its users + assignment.COMPUTED_ASSIGNMENTS_REGION.invalidate() + @domains_configured @exception_translated('group') - def add_user_to_group(self, user_id, group_id): + def add_user_to_group(self, user_id, group_id, initiator=None): @exception_translated('user') def get_entity_info_for_user(public_id): return self._get_domain_driver_and_entity_id(public_id) @@ -1031,9 +1086,15 @@ class Manager(manager.Manager): group_driver.add_user_to_group(user_entity_id, group_entity_id) + # Invalidate user role assignments cache region, as it may now need to + # include role assignments from the specified group to its users + assignment.COMPUTED_ASSIGNMENTS_REGION.invalidate() + notifications.Audit.added_to(self._GROUP, group_id, self._USER, + user_id, initiator) + @domains_configured @exception_translated('group') - def remove_user_from_group(self, user_id, group_id): + def remove_user_from_group(self, user_id, group_id, initiator=None): @exception_translated('user') def get_entity_info_for_user(public_id): return self._get_domain_driver_and_entity_id(public_id) @@ -1051,7 +1112,12 @@ class Manager(manager.Manager): group_driver.remove_user_from_group(user_entity_id, group_entity_id) self.emit_invalidate_user_token_persistence(user_id) - @notifications.internal(notifications.INVALIDATE_USER_TOKEN_PERSISTENCE) + # Invalidate user role assignments cache region, as it may be caching + # role assignments expanded from this group to this user + assignment.COMPUTED_ASSIGNMENTS_REGION.invalidate() + notifications.Audit.removed_from(self._GROUP, group_id, self._USER, + user_id, initiator) + def emit_invalidate_user_token_persistence(self, user_id): """Emit a notification to the callback system to revoke user tokens. @@ -1061,10 +1127,10 @@ class Manager(manager.Manager): :param user_id: user identifier :type user_id: string """ - pass + notifications.Audit.internal( + notifications.INVALIDATE_USER_TOKEN_PERSISTENCE, user_id + ) - @notifications.internal( - notifications.INVALIDATE_USER_PROJECT_TOKEN_PERSISTENCE) def emit_invalidate_grant_token_persistence(self, user_project): """Emit a notification to the callback system to revoke grant tokens. @@ -1074,14 +1140,17 @@ class Manager(manager.Manager): :param user_project: {'user_id': user_id, 'project_id': project_id} :type user_project: dict """ - pass + notifications.Audit.internal( + notifications.INVALIDATE_USER_PROJECT_TOKEN_PERSISTENCE, + user_project + ) - @manager.response_truncated @domains_configured @exception_translated('user') def list_groups_for_user(self, user_id, hints=None): domain_id, driver, entity_id = ( self._get_domain_driver_and_entity_id(user_id)) + self._set_list_limit_in_hints(hints, driver) hints = hints or driver_hints.Hints() if not driver.is_domain_aware(): # We are effectively satisfying any domain_id filter by the above @@ -1091,11 +1160,11 @@ class Manager(manager.Manager): return self._set_domain_id_and_mapping( ref_list, domain_id, driver, mapping.EntityType.GROUP) - @manager.response_truncated @domains_configured @exception_translated('group') def list_groups(self, domain_scope=None, hints=None): driver = self._select_identity_driver(domain_scope) + self._set_list_limit_in_hints(hints, driver) hints = hints or driver_hints.Hints() if driver.is_domain_aware(): # Force the domain_scope into the hint to ensure that we only get @@ -1109,12 +1178,12 @@ class Manager(manager.Manager): return self._set_domain_id_and_mapping( ref_list, domain_scope, driver, mapping.EntityType.GROUP) - @manager.response_truncated @domains_configured @exception_translated('group') def list_users_in_group(self, group_id, hints=None): domain_id, driver, entity_id = ( self._get_domain_driver_and_entity_id(group_id)) + self._set_list_limit_in_hints(hints, driver) hints = hints or driver_hints.Hints() if not driver.is_domain_aware(): # We are effectively satisfying any domain_id filter by the above @@ -1154,18 +1223,62 @@ class Manager(manager.Manager): update_dict = {'password': new_password} self.update_user(user_id, update_dict) + @MEMOIZE + def shadow_federated_user(self, idp_id, protocol_id, unique_id, + display_name): + """Shadows a federated user by mapping to a user. + + :param idp_id: identity provider id + :param protocol_id: protocol id + :param unique_id: unique id for the user within the IdP + :param display_name: user's display name + + :returns: dictionary of the mapped User entity + """ + user_dict = {} + try: + self.shadow_users_api.update_federated_user_display_name( + idp_id, protocol_id, unique_id, display_name) + user_dict = self.shadow_users_api.get_federated_user( + idp_id, protocol_id, unique_id) + except exception.UserNotFound: + federated_dict = { + 'idp_id': idp_id, + 'protocol_id': protocol_id, + 'unique_id': unique_id, + 'display_name': display_name + } + user_dict = self.shadow_users_api.create_federated_user( + federated_dict) + return user_dict + @six.add_metaclass(abc.ABCMeta) class IdentityDriverV8(object): """Interface description for an Identity driver.""" + def _get_conf(self): + try: + return self.conf or CONF + except AttributeError: + return CONF + def _get_list_limit(self): - return CONF.identity.list_limit or CONF.list_limit + conf = self._get_conf() + # use list_limit from domain-specific config. If list_limit in + # domain-specific config is not set, look it up in the default config + return (conf.identity.list_limit or conf.list_limit or + CONF.identity.list_limit or CONF.list_limit) def is_domain_aware(self): """Indicates if Driver supports domains.""" return True + def default_assignment_driver(self): + # TODO(morganfainberg): To be removed when assignment driver based + # upon [identity]/driver option is removed in the "O" release. + return 'sql' + @property def is_sql(self): """Indicates if this Driver uses SQL.""" @@ -1183,8 +1296,9 @@ class IdentityDriverV8(object): @abc.abstractmethod def authenticate(self, user_id, password): """Authenticate a given user and password. + :returns: user_ref - :raises: AssertionError + :raises AssertionError: If user or password is invalid. """ raise exception.NotImplemented() # pragma: no cover @@ -1194,7 +1308,7 @@ class IdentityDriverV8(object): def create_user(self, user_id, user): """Creates a new user. - :raises: keystone.exception.Conflict + :raises keystone.exception.Conflict: If a duplicate user exists. """ raise exception.NotImplemented() # pragma: no cover @@ -1229,7 +1343,7 @@ class IdentityDriverV8(object): """Get a user by ID. :returns: user_ref - :raises: keystone.exception.UserNotFound + :raises keystone.exception.UserNotFound: If the user doesn't exist. """ raise exception.NotImplemented() # pragma: no cover @@ -1238,8 +1352,8 @@ class IdentityDriverV8(object): def update_user(self, user_id, user): """Updates an existing user. - :raises: keystone.exception.UserNotFound, - keystone.exception.Conflict + :raises keystone.exception.UserNotFound: If the user doesn't exist. + :raises keystone.exception.Conflict: If a duplicate user exists. """ raise exception.NotImplemented() # pragma: no cover @@ -1248,8 +1362,8 @@ class IdentityDriverV8(object): def add_user_to_group(self, user_id, group_id): """Adds a user to a group. - :raises: keystone.exception.UserNotFound, - keystone.exception.GroupNotFound + :raises keystone.exception.UserNotFound: If the user doesn't exist. + :raises keystone.exception.GroupNotFound: If the group doesn't exist. """ raise exception.NotImplemented() # pragma: no cover @@ -1258,8 +1372,8 @@ class IdentityDriverV8(object): def check_user_in_group(self, user_id, group_id): """Checks if a user is a member of a group. - :raises: keystone.exception.UserNotFound, - keystone.exception.GroupNotFound + :raises keystone.exception.UserNotFound: If the user doesn't exist. + :raises keystone.exception.GroupNotFound: If the group doesn't exist. """ raise exception.NotImplemented() # pragma: no cover @@ -1268,7 +1382,7 @@ class IdentityDriverV8(object): def remove_user_from_group(self, user_id, group_id): """Removes a user from a group. - :raises: keystone.exception.NotFound + :raises keystone.exception.NotFound: If the entity not found. """ raise exception.NotImplemented() # pragma: no cover @@ -1277,7 +1391,7 @@ class IdentityDriverV8(object): def delete_user(self, user_id): """Deletes an existing user. - :raises: keystone.exception.UserNotFound + :raises keystone.exception.UserNotFound: If the user doesn't exist. """ raise exception.NotImplemented() # pragma: no cover @@ -1287,7 +1401,7 @@ class IdentityDriverV8(object): """Get a user by name. :returns: user_ref - :raises: keystone.exception.UserNotFound + :raises keystone.exception.UserNotFound: If the user doesn't exist. """ raise exception.NotImplemented() # pragma: no cover @@ -1298,7 +1412,7 @@ class IdentityDriverV8(object): def create_group(self, group_id, group): """Creates a new group. - :raises: keystone.exception.Conflict + :raises keystone.exception.Conflict: If a duplicate group exists. """ raise exception.NotImplemented() # pragma: no cover @@ -1333,7 +1447,7 @@ class IdentityDriverV8(object): """Get a group by ID. :returns: group_ref - :raises: keystone.exception.GroupNotFound + :raises keystone.exception.GroupNotFound: If the group doesn't exist. """ raise exception.NotImplemented() # pragma: no cover @@ -1343,7 +1457,7 @@ class IdentityDriverV8(object): """Get a group by name. :returns: group_ref - :raises: keystone.exception.GroupNotFound + :raises keystone.exception.GroupNotFound: If the group doesn't exist. """ raise exception.NotImplemented() # pragma: no cover @@ -1352,8 +1466,8 @@ class IdentityDriverV8(object): def update_group(self, group_id, group): """Updates an existing group. - :raises: keystone.exceptionGroupNotFound, - keystone.exception.Conflict + :raises keystone.exception.GroupNotFound: If the group doesn't exist. + :raises keystone.exception.Conflict: If a duplicate group exists. """ raise exception.NotImplemented() # pragma: no cover @@ -1362,7 +1476,7 @@ class IdentityDriverV8(object): def delete_group(self, group_id): """Deletes an existing group. - :raises: keystone.exception.GroupNotFound + :raises keystone.exception.GroupNotFound: If the group doesn't exist. """ raise exception.NotImplemented() # pragma: no cover @@ -1446,3 +1560,54 @@ class MappingDriverV8(object): MappingDriver = manager.create_legacy_driver(MappingDriverV8) + + +@dependency.provider('shadow_users_api') +class ShadowUsersManager(manager.Manager): + """Default pivot point for the Shadow Users backend.""" + + driver_namespace = 'keystone.identity.shadow_users' + + def __init__(self): + super(ShadowUsersManager, self).__init__(CONF.shadow_users.driver) + + +@six.add_metaclass(abc.ABCMeta) +class ShadowUsersDriverV9(object): + """Interface description for an Shadow Users driver.""" + + @abc.abstractmethod + def create_federated_user(self, federated_dict): + """Create a new user with the federated identity + + :param dict federated_dict: Reference to the federated user + :param user_id: user ID for linking to the federated identity + :returns dict: Containing the user reference + + """ + raise exception.NotImplemented() + + @abc.abstractmethod + def get_federated_user(self, idp_id, protocol_id, unique_id): + """Returns the found user for the federated identity + + :param idp_id: The identity provider ID + :param protocol_id: The federation protocol ID + :param unique_id: The unique ID for the user + :returns dict: Containing the user reference + + """ + raise exception.NotImplemented() + + @abc.abstractmethod + def update_federated_user_display_name(self, idp_id, protocol_id, + unique_id, display_name): + """Updates federated user's display name if changed + + :param idp_id: The identity provider ID + :param protocol_id: The federation protocol ID + :param unique_id: The unique ID for the user + :param display_name: The user's display name + + """ + raise exception.NotImplemented() diff --git a/keystone-moon/keystone/identity/mapping_backends/sql.py b/keystone-moon/keystone/identity/mapping_backends/sql.py index 7ab4ef52..91b33dd7 100644 --- a/keystone-moon/keystone/identity/mapping_backends/sql.py +++ b/keystone-moon/keystone/identity/mapping_backends/sql.py @@ -23,7 +23,7 @@ class IDMapping(sql.ModelBase, sql.ModelDictMixin): public_id = sql.Column(sql.String(64), primary_key=True) domain_id = sql.Column(sql.String(64), nullable=False) local_id = sql.Column(sql.String(64), nullable=False) - # NOTE(henry-nash); Postgres requires a name to be defined for an Enum + # NOTE(henry-nash): Postgres requires a name to be defined for an Enum entity_type = sql.Column( sql.Enum(identity_mapping.EntityType.USER, identity_mapping.EntityType.GROUP, @@ -32,7 +32,7 @@ class IDMapping(sql.ModelBase, sql.ModelDictMixin): # Unique constraint to ensure you can't store more than one mapping to the # same underlying values __table_args__ = ( - sql.UniqueConstraint('domain_id', 'local_id', 'entity_type'), {}) + sql.UniqueConstraint('domain_id', 'local_id', 'entity_type'),) @dependency.requires('id_generator_api') @@ -45,27 +45,27 @@ class Mapping(identity.MappingDriverV8): # work if we hashed all the entries, even those that already generate # UUIDs, like SQL. Further, this would only work if the generation # algorithm was immutable (e.g. it had always been sha256). - session = sql.get_session() - query = session.query(IDMapping.public_id) - query = query.filter_by(domain_id=local_entity['domain_id']) - query = query.filter_by(local_id=local_entity['local_id']) - query = query.filter_by(entity_type=local_entity['entity_type']) - try: - public_ref = query.one() - public_id = public_ref.public_id - return public_id - except sql.NotFound: - return None + with sql.session_for_read() as session: + query = session.query(IDMapping.public_id) + query = query.filter_by(domain_id=local_entity['domain_id']) + query = query.filter_by(local_id=local_entity['local_id']) + query = query.filter_by(entity_type=local_entity['entity_type']) + try: + public_ref = query.one() + public_id = public_ref.public_id + return public_id + except sql.NotFound: + return None def get_id_mapping(self, public_id): - session = sql.get_session() - mapping_ref = session.query(IDMapping).get(public_id) - if mapping_ref: - return mapping_ref.to_dict() + with sql.session_for_read() as session: + mapping_ref = session.query(IDMapping).get(public_id) + if mapping_ref: + return mapping_ref.to_dict() def create_id_mapping(self, local_entity, public_id=None): entity = local_entity.copy() - with sql.transaction() as session: + with sql.session_for_write() as session: if public_id is None: public_id = self.id_generator_api.generate_public_ID(entity) entity['public_id'] = public_id @@ -74,24 +74,25 @@ class Mapping(identity.MappingDriverV8): return public_id def delete_id_mapping(self, public_id): - with sql.transaction() as session: + with sql.session_for_write() as session: try: session.query(IDMapping).filter( IDMapping.public_id == public_id).delete() - except sql.NotFound: + except sql.NotFound: # nosec # NOTE(morganfainberg): There is nothing to delete and nothing # to do. pass def purge_mappings(self, purge_filter): - session = sql.get_session() - query = session.query(IDMapping) - if 'domain_id' in purge_filter: - query = query.filter_by(domain_id=purge_filter['domain_id']) - if 'public_id' in purge_filter: - query = query.filter_by(public_id=purge_filter['public_id']) - if 'local_id' in purge_filter: - query = query.filter_by(local_id=purge_filter['local_id']) - if 'entity_type' in purge_filter: - query = query.filter_by(entity_type=purge_filter['entity_type']) - query.delete() + with sql.session_for_write() as session: + query = session.query(IDMapping) + if 'domain_id' in purge_filter: + query = query.filter_by(domain_id=purge_filter['domain_id']) + if 'public_id' in purge_filter: + query = query.filter_by(public_id=purge_filter['public_id']) + if 'local_id' in purge_filter: + query = query.filter_by(local_id=purge_filter['local_id']) + if 'entity_type' in purge_filter: + query = query.filter_by( + entity_type=purge_filter['entity_type']) + query.delete() diff --git a/keystone-moon/keystone/identity/shadow_backends/__init__.py b/keystone-moon/keystone/identity/shadow_backends/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/keystone-moon/keystone/identity/shadow_backends/sql.py b/keystone-moon/keystone/identity/shadow_backends/sql.py new file mode 100644 index 00000000..af5a995b --- /dev/null +++ b/keystone-moon/keystone/identity/shadow_backends/sql.py @@ -0,0 +1,73 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import uuid + +from keystone.common import sql +from keystone import exception +from keystone import identity +from keystone.identity.backends import sql as model + + +class ShadowUsers(identity.ShadowUsersDriverV9): + @sql.handle_conflicts(conflict_type='federated_user') + def create_federated_user(self, federated_dict): + user = { + 'id': uuid.uuid4().hex, + 'enabled': True + } + with sql.session_for_write() as session: + federated_ref = model.FederatedUser.from_dict(federated_dict) + user_ref = model.User.from_dict(user) + user_ref.federated_users.append(federated_ref) + session.add(user_ref) + return identity.filter_user(user_ref.to_dict()) + + def get_federated_user(self, idp_id, protocol_id, unique_id): + user_ref = self._get_federated_user(idp_id, protocol_id, unique_id) + return identity.filter_user(user_ref.to_dict()) + + def _get_federated_user(self, idp_id, protocol_id, unique_id): + """Returns the found user for the federated identity + + :param idp_id: The identity provider ID + :param protocol_id: The federation protocol ID + :param unique_id: The user's unique ID (unique within the IdP) + :returns User: Returns a reference to the User + + """ + with sql.session_for_read() as session: + query = session.query(model.User).outerjoin(model.LocalUser) + query = query.join(model.FederatedUser) + query = query.filter(model.FederatedUser.idp_id == idp_id) + query = query.filter(model.FederatedUser.protocol_id == + protocol_id) + query = query.filter(model.FederatedUser.unique_id == unique_id) + try: + user_ref = query.one() + except sql.NotFound: + raise exception.UserNotFound(user_id=unique_id) + return user_ref + + @sql.handle_conflicts(conflict_type='federated_user') + def update_federated_user_display_name(self, idp_id, protocol_id, + unique_id, display_name): + with sql.session_for_write() as session: + query = session.query(model.FederatedUser) + query = query.filter(model.FederatedUser.idp_id == idp_id) + query = query.filter(model.FederatedUser.protocol_id == + protocol_id) + query = query.filter(model.FederatedUser.unique_id == unique_id) + query = query.filter(model.FederatedUser.display_name != + display_name) + query.update({'display_name': display_name}) + return -- cgit 1.2.3-korg