aboutsummaryrefslogtreecommitdiffstats
path: root/keystone-moon/keystone/identity/backends
diff options
context:
space:
mode:
authorWuKong <rebirthmonkey@gmail.com>2015-06-30 18:47:29 +0200
committerWuKong <rebirthmonkey@gmail.com>2015-06-30 18:47:29 +0200
commitb8c756ecdd7cced1db4300935484e8c83701c82e (patch)
tree87e51107d82b217ede145de9d9d59e2100725bd7 /keystone-moon/keystone/identity/backends
parentc304c773bae68fb854ed9eab8fb35c4ef17cf136 (diff)
migrate moon code from github to opnfv
Change-Id: Ice53e368fd1114d56a75271aa9f2e598e3eba604 Signed-off-by: WuKong <rebirthmonkey@gmail.com>
Diffstat (limited to 'keystone-moon/keystone/identity/backends')
-rw-r--r--keystone-moon/keystone/identity/backends/__init__.py0
-rw-r--r--keystone-moon/keystone/identity/backends/ldap.py402
-rw-r--r--keystone-moon/keystone/identity/backends/sql.py314
3 files changed, 716 insertions, 0 deletions
diff --git a/keystone-moon/keystone/identity/backends/__init__.py b/keystone-moon/keystone/identity/backends/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/keystone-moon/keystone/identity/backends/__init__.py
diff --git a/keystone-moon/keystone/identity/backends/ldap.py b/keystone-moon/keystone/identity/backends/ldap.py
new file mode 100644
index 00000000..0f7ee450
--- /dev/null
+++ b/keystone-moon/keystone/identity/backends/ldap.py
@@ -0,0 +1,402 @@
+# Copyright 2012 OpenStack Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+from __future__ import absolute_import
+import uuid
+
+import ldap
+import ldap.filter
+from oslo_config import cfg
+from oslo_log import log
+import six
+
+from keystone import clean
+from keystone.common import driver_hints
+from keystone.common import ldap as common_ldap
+from keystone.common import models
+from keystone import exception
+from keystone.i18n import _
+from keystone import identity
+
+
+CONF = cfg.CONF
+LOG = log.getLogger(__name__)
+
+
+class Identity(identity.Driver):
+ def __init__(self, conf=None):
+ super(Identity, self).__init__()
+ if conf is None:
+ conf = CONF
+ self.user = UserApi(conf)
+ self.group = GroupApi(conf)
+
+ def default_assignment_driver(self):
+ return "keystone.assignment.backends.ldap.Assignment"
+
+ def is_domain_aware(self):
+ return False
+
+ def generates_uuids(self):
+ return False
+
+ # Identity interface
+
+ def authenticate(self, user_id, password):
+ try:
+ user_ref = self._get_user(user_id)
+ except exception.UserNotFound:
+ raise AssertionError(_('Invalid user / password'))
+ if not user_id or not password:
+ raise AssertionError(_('Invalid user / password'))
+ conn = None
+ try:
+ conn = self.user.get_connection(user_ref['dn'],
+ password, end_user_auth=True)
+ if not conn:
+ raise AssertionError(_('Invalid user / password'))
+ except Exception:
+ raise AssertionError(_('Invalid user / password'))
+ finally:
+ if conn:
+ conn.unbind_s()
+ return self.user.filter_attributes(user_ref)
+
+ def _get_user(self, user_id):
+ return self.user.get(user_id)
+
+ def get_user(self, user_id):
+ return self.user.get_filtered(user_id)
+
+ def list_users(self, hints):
+ return self.user.get_all_filtered(hints)
+
+ def get_user_by_name(self, user_name, domain_id):
+ # domain_id will already have been handled in the Manager layer,
+ # parameter left in so this matches the Driver specification
+ return self.user.filter_attributes(self.user.get_by_name(user_name))
+
+ # CRUD
+ def create_user(self, user_id, user):
+ self.user.check_allow_create()
+ user_ref = self.user.create(user)
+ return self.user.filter_attributes(user_ref)
+
+ def update_user(self, user_id, user):
+ self.user.check_allow_update()
+ old_obj = self.user.get(user_id)
+ if 'name' in user and old_obj.get('name') != user['name']:
+ raise exception.Conflict(_('Cannot change user name'))
+
+ if self.user.enabled_mask:
+ self.user.mask_enabled_attribute(user)
+ elif self.user.enabled_invert and not self.user.enabled_emulation:
+ # We need to invert the enabled value for the old model object
+ # to prevent the LDAP update code from thinking that the enabled
+ # values are already equal.
+ user['enabled'] = not user['enabled']
+ old_obj['enabled'] = not old_obj['enabled']
+
+ self.user.update(user_id, user, old_obj)
+ return self.user.get_filtered(user_id)
+
+ def delete_user(self, user_id):
+ self.user.check_allow_delete()
+ user = self.user.get(user_id)
+ user_dn = user['dn']
+ groups = self.group.list_user_groups(user_dn)
+ for group in groups:
+ self.group.remove_user(user_dn, group['id'], user_id)
+
+ if hasattr(user, 'tenant_id'):
+ self.project.remove_user(user.tenant_id, user_dn)
+ self.user.delete(user_id)
+
+ def create_group(self, group_id, group):
+ self.group.check_allow_create()
+ group['name'] = clean.group_name(group['name'])
+ return common_ldap.filter_entity(self.group.create(group))
+
+ def get_group(self, group_id):
+ return self.group.get_filtered(group_id)
+
+ def get_group_by_name(self, group_name, domain_id):
+ # domain_id will already have been handled in the Manager layer,
+ # parameter left in so this matches the Driver specification
+ return self.group.get_filtered_by_name(group_name)
+
+ def update_group(self, group_id, group):
+ self.group.check_allow_update()
+ if 'name' in group:
+ group['name'] = clean.group_name(group['name'])
+ return common_ldap.filter_entity(self.group.update(group_id, group))
+
+ def delete_group(self, group_id):
+ self.group.check_allow_delete()
+ return self.group.delete(group_id)
+
+ def add_user_to_group(self, user_id, group_id):
+ user_ref = self._get_user(user_id)
+ user_dn = user_ref['dn']
+ self.group.add_user(user_dn, group_id, user_id)
+
+ def remove_user_from_group(self, user_id, group_id):
+ user_ref = self._get_user(user_id)
+ user_dn = user_ref['dn']
+ self.group.remove_user(user_dn, group_id, user_id)
+
+ def list_groups_for_user(self, user_id, hints):
+ user_ref = self._get_user(user_id)
+ user_dn = user_ref['dn']
+ return self.group.list_user_groups_filtered(user_dn, hints)
+
+ def list_groups(self, hints):
+ return self.group.get_all_filtered(hints)
+
+ def list_users_in_group(self, group_id, hints):
+ users = []
+ for user_dn in self.group.list_group_users(group_id):
+ user_id = self.user._dn_to_id(user_dn)
+ try:
+ users.append(self.user.get_filtered(user_id))
+ except exception.UserNotFound:
+ LOG.debug(("Group member '%(user_dn)s' not found in"
+ " '%(group_id)s'. The user should be removed"
+ " from the group. The user will be ignored."),
+ dict(user_dn=user_dn, group_id=group_id))
+ return users
+
+ def check_user_in_group(self, user_id, group_id):
+ user_refs = self.list_users_in_group(group_id, driver_hints.Hints())
+ for x in user_refs:
+ if x['id'] == user_id:
+ break
+ else:
+ # Try to fetch the user to see if it even exists. This
+ # will raise a more accurate exception.
+ self.get_user(user_id)
+ raise exception.NotFound(_("User '%(user_id)s' not found in"
+ " group '%(group_id)s'") %
+ {'user_id': user_id,
+ 'group_id': group_id})
+
+
+# TODO(termie): turn this into a data object and move logic to driver
+class UserApi(common_ldap.EnabledEmuMixIn, common_ldap.BaseLdap):
+ DEFAULT_OU = 'ou=Users'
+ DEFAULT_STRUCTURAL_CLASSES = ['person']
+ DEFAULT_ID_ATTR = 'cn'
+ DEFAULT_OBJECTCLASS = 'inetOrgPerson'
+ NotFound = exception.UserNotFound
+ options_name = 'user'
+ attribute_options_names = {'password': 'pass',
+ 'email': 'mail',
+ 'name': 'name',
+ 'enabled': 'enabled',
+ 'default_project_id': 'default_project_id'}
+ immutable_attrs = ['id']
+
+ model = models.User
+
+ def __init__(self, conf):
+ super(UserApi, self).__init__(conf)
+ self.enabled_mask = conf.ldap.user_enabled_mask
+ self.enabled_default = conf.ldap.user_enabled_default
+ self.enabled_invert = conf.ldap.user_enabled_invert
+ self.enabled_emulation = conf.ldap.user_enabled_emulation
+
+ def _ldap_res_to_model(self, res):
+ obj = super(UserApi, self)._ldap_res_to_model(res)
+ if self.enabled_mask != 0:
+ enabled = int(obj.get('enabled', self.enabled_default))
+ obj['enabled'] = ((enabled & self.enabled_mask) !=
+ self.enabled_mask)
+ elif self.enabled_invert and not self.enabled_emulation:
+ # This could be a bool or a string. If it's a string,
+ # we need to convert it so we can invert it properly.
+ enabled = obj.get('enabled', self.enabled_default)
+ if isinstance(enabled, six.string_types):
+ if enabled.lower() == 'true':
+ enabled = True
+ else:
+ enabled = False
+ obj['enabled'] = not enabled
+ obj['dn'] = res[0]
+
+ return obj
+
+ def mask_enabled_attribute(self, values):
+ value = values['enabled']
+ values.setdefault('enabled_nomask', int(self.enabled_default))
+ if value != ((values['enabled_nomask'] & self.enabled_mask) !=
+ self.enabled_mask):
+ values['enabled_nomask'] ^= self.enabled_mask
+ values['enabled'] = values['enabled_nomask']
+ del values['enabled_nomask']
+
+ def create(self, values):
+ if self.enabled_mask:
+ orig_enabled = values['enabled']
+ self.mask_enabled_attribute(values)
+ elif self.enabled_invert and not self.enabled_emulation:
+ orig_enabled = values['enabled']
+ if orig_enabled is not None:
+ values['enabled'] = not orig_enabled
+ else:
+ values['enabled'] = self.enabled_default
+ values = super(UserApi, self).create(values)
+ if self.enabled_mask or (self.enabled_invert and
+ not self.enabled_emulation):
+ values['enabled'] = orig_enabled
+ return values
+
+ def get_filtered(self, user_id):
+ user = self.get(user_id)
+ return self.filter_attributes(user)
+
+ def get_all_filtered(self, hints):
+ query = self.filter_query(hints)
+ return [self.filter_attributes(user) for user in self.get_all(query)]
+
+ def filter_attributes(self, user):
+ return identity.filter_user(common_ldap.filter_entity(user))
+
+ def is_user(self, dn):
+ """Returns True if the entry is a user."""
+
+ # NOTE(blk-u): It's easy to check if the DN is under the User tree,
+ # but may not be accurate. A more accurate test would be to fetch the
+ # entry to see if it's got the user objectclass, but this could be
+ # really expensive considering how this is used.
+
+ return common_ldap.dn_startswith(dn, self.tree_dn)
+
+
+class GroupApi(common_ldap.BaseLdap):
+ DEFAULT_OU = 'ou=UserGroups'
+ DEFAULT_STRUCTURAL_CLASSES = []
+ DEFAULT_OBJECTCLASS = 'groupOfNames'
+ DEFAULT_ID_ATTR = 'cn'
+ DEFAULT_MEMBER_ATTRIBUTE = 'member'
+ NotFound = exception.GroupNotFound
+ options_name = 'group'
+ attribute_options_names = {'description': 'desc',
+ 'name': 'name'}
+ immutable_attrs = ['name']
+ model = models.Group
+
+ def _ldap_res_to_model(self, res):
+ model = super(GroupApi, self)._ldap_res_to_model(res)
+ model['dn'] = res[0]
+ return model
+
+ def __init__(self, conf):
+ super(GroupApi, self).__init__(conf)
+ self.member_attribute = (conf.ldap.group_member_attribute
+ or self.DEFAULT_MEMBER_ATTRIBUTE)
+
+ def create(self, values):
+ data = values.copy()
+ if data.get('id') is None:
+ data['id'] = uuid.uuid4().hex
+ if 'description' in data and data['description'] in ['', None]:
+ data.pop('description')
+ return super(GroupApi, self).create(data)
+
+ def delete(self, group_id):
+ if self.subtree_delete_enabled:
+ super(GroupApi, self).deleteTree(group_id)
+ else:
+ # TODO(spzala): this is only placeholder for group and domain
+ # role support which will be added under bug 1101287
+
+ group_ref = self.get(group_id)
+ group_dn = group_ref['dn']
+ if group_dn:
+ self._delete_tree_nodes(group_dn, ldap.SCOPE_ONELEVEL)
+ super(GroupApi, self).delete(group_id)
+
+ def update(self, group_id, values):
+ old_obj = self.get(group_id)
+ return super(GroupApi, self).update(group_id, values, old_obj)
+
+ def add_user(self, user_dn, group_id, user_id):
+ group_ref = self.get(group_id)
+ group_dn = group_ref['dn']
+ try:
+ super(GroupApi, self).add_member(user_dn, group_dn)
+ except exception.Conflict:
+ raise exception.Conflict(_(
+ 'User %(user_id)s is already a member of group %(group_id)s') %
+ {'user_id': user_id, 'group_id': group_id})
+
+ def remove_user(self, user_dn, group_id, user_id):
+ group_ref = self.get(group_id)
+ group_dn = group_ref['dn']
+ try:
+ super(GroupApi, self).remove_member(user_dn, group_dn)
+ except ldap.NO_SUCH_ATTRIBUTE:
+ raise exception.UserNotFound(user_id=user_id)
+
+ def list_user_groups(self, user_dn):
+ """Return a list of groups for which the user is a member."""
+
+ user_dn_esc = ldap.filter.escape_filter_chars(user_dn)
+ query = '(&(objectClass=%s)(%s=%s)%s)' % (self.object_class,
+ self.member_attribute,
+ user_dn_esc,
+ self.ldap_filter or '')
+ return self.get_all(query)
+
+ def list_user_groups_filtered(self, user_dn, hints):
+ """Return a filtered list of groups for which the user is a member."""
+
+ user_dn_esc = ldap.filter.escape_filter_chars(user_dn)
+ query = '(&(objectClass=%s)(%s=%s)%s)' % (self.object_class,
+ self.member_attribute,
+ user_dn_esc,
+ self.ldap_filter or '')
+ return self.get_all_filtered(hints, query)
+
+ def list_group_users(self, group_id):
+ """Return a list of user dns which are members of a group."""
+ group_ref = self.get(group_id)
+ group_dn = group_ref['dn']
+
+ try:
+ attrs = self._ldap_get_list(group_dn, ldap.SCOPE_BASE,
+ attrlist=[self.member_attribute])
+ except ldap.NO_SUCH_OBJECT:
+ raise self.NotFound(group_id=group_id)
+
+ users = []
+ for dn, member in attrs:
+ user_dns = member.get(self.member_attribute, [])
+ for user_dn in user_dns:
+ if self._is_dumb_member(user_dn):
+ continue
+ users.append(user_dn)
+ return users
+
+ def get_filtered(self, group_id):
+ group = self.get(group_id)
+ return common_ldap.filter_entity(group)
+
+ def get_filtered_by_name(self, group_name):
+ group = self.get_by_name(group_name)
+ return common_ldap.filter_entity(group)
+
+ def get_all_filtered(self, hints, query=None):
+ query = self.filter_query(hints, query)
+ return [common_ldap.filter_entity(group)
+ for group in self.get_all(query)]
diff --git a/keystone-moon/keystone/identity/backends/sql.py b/keystone-moon/keystone/identity/backends/sql.py
new file mode 100644
index 00000000..39868416
--- /dev/null
+++ b/keystone-moon/keystone/identity/backends/sql.py
@@ -0,0 +1,314 @@
+# Copyright 2012 OpenStack Foundation
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from oslo_config import cfg
+
+from keystone.common import sql
+from keystone.common import utils
+from keystone import exception
+from keystone.i18n import _
+from keystone import identity
+
+
+CONF = cfg.CONF
+
+
+class User(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'user'
+ attributes = ['id', 'name', 'domain_id', 'password', 'enabled',
+ 'default_project_id']
+ id = sql.Column(sql.String(64), primary_key=True)
+ name = sql.Column(sql.String(255), nullable=False)
+ domain_id = sql.Column(sql.String(64), nullable=False)
+ password = sql.Column(sql.String(128))
+ enabled = sql.Column(sql.Boolean)
+ extra = sql.Column(sql.JsonBlob())
+ default_project_id = sql.Column(sql.String(64))
+ # Unique constraint across two columns to create the separation
+ # rather than just only 'name' being unique
+ __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
+
+ def to_dict(self, include_extra_dict=False):
+ d = super(User, self).to_dict(include_extra_dict=include_extra_dict)
+ if 'default_project_id' in d and d['default_project_id'] is None:
+ del d['default_project_id']
+ return d
+
+
+class Group(sql.ModelBase, sql.DictBase):
+ __tablename__ = 'group'
+ attributes = ['id', 'name', 'domain_id', 'description']
+ id = sql.Column(sql.String(64), primary_key=True)
+ name = sql.Column(sql.String(64), nullable=False)
+ domain_id = sql.Column(sql.String(64), nullable=False)
+ description = sql.Column(sql.Text())
+ extra = sql.Column(sql.JsonBlob())
+ # Unique constraint across two columns to create the separation
+ # rather than just only 'name' being unique
+ __table_args__ = (sql.UniqueConstraint('domain_id', 'name'), {})
+
+
+class UserGroupMembership(sql.ModelBase, sql.DictBase):
+ """Group membership join table."""
+ __tablename__ = 'user_group_membership'
+ user_id = sql.Column(sql.String(64),
+ sql.ForeignKey('user.id'),
+ primary_key=True)
+ group_id = sql.Column(sql.String(64),
+ sql.ForeignKey('group.id'),
+ primary_key=True)
+
+
+class Identity(identity.Driver):
+ # NOTE(henry-nash): Override the __init__() method so as to take a
+ # config parameter to enable sql to be used as a domain-specific driver.
+ def __init__(self, conf=None):
+ super(Identity, self).__init__()
+
+ def default_assignment_driver(self):
+ return "keystone.assignment.backends.sql.Assignment"
+
+ @property
+ def is_sql(self):
+ return True
+
+ def _check_password(self, password, user_ref):
+ """Check the specified password against the data store.
+
+ Note that we'll pass in the entire user_ref in case the subclass
+ needs things like user_ref.get('name')
+ For further justification, please see the follow up suggestion at
+ https://blueprints.launchpad.net/keystone/+spec/sql-identiy-pam
+
+ """
+ return utils.check_password(password, user_ref.password)
+
+ # Identity interface
+ def authenticate(self, user_id, password):
+ session = sql.get_session()
+ user_ref = None
+ try:
+ user_ref = self._get_user(session, user_id)
+ except exception.UserNotFound:
+ raise AssertionError(_('Invalid user / password'))
+ if not self._check_password(password, user_ref):
+ raise AssertionError(_('Invalid user / password'))
+ return identity.filter_user(user_ref.to_dict())
+
+ # user crud
+
+ @sql.handle_conflicts(conflict_type='user')
+ def create_user(self, user_id, user):
+ user = utils.hash_user_password(user)
+ session = sql.get_session()
+ with session.begin():
+ user_ref = User.from_dict(user)
+ session.add(user_ref)
+ return identity.filter_user(user_ref.to_dict())
+
+ @sql.truncated
+ def list_users(self, hints):
+ session = sql.get_session()
+ query = session.query(User)
+ user_refs = sql.filter_limit_query(User, query, hints)
+ return [identity.filter_user(x.to_dict()) for x in user_refs]
+
+ def _get_user(self, session, user_id):
+ user_ref = session.query(User).get(user_id)
+ if not user_ref:
+ raise exception.UserNotFound(user_id=user_id)
+ return user_ref
+
+ def get_user(self, user_id):
+ session = sql.get_session()
+ return identity.filter_user(self._get_user(session, user_id).to_dict())
+
+ def get_user_by_name(self, user_name, domain_id):
+ session = sql.get_session()
+ query = session.query(User)
+ query = query.filter_by(name=user_name)
+ query = query.filter_by(domain_id=domain_id)
+ try:
+ user_ref = query.one()
+ except sql.NotFound:
+ raise exception.UserNotFound(user_id=user_name)
+ return identity.filter_user(user_ref.to_dict())
+
+ @sql.handle_conflicts(conflict_type='user')
+ def update_user(self, user_id, user):
+ session = sql.get_session()
+
+ with session.begin():
+ user_ref = self._get_user(session, user_id)
+ old_user_dict = user_ref.to_dict()
+ user = utils.hash_user_password(user)
+ for k in user:
+ old_user_dict[k] = user[k]
+ new_user = User.from_dict(old_user_dict)
+ for attr in User.attributes:
+ if attr != 'id':
+ setattr(user_ref, attr, getattr(new_user, attr))
+ user_ref.extra = new_user.extra
+ return identity.filter_user(user_ref.to_dict(include_extra_dict=True))
+
+ def add_user_to_group(self, user_id, group_id):
+ session = sql.get_session()
+ self.get_group(group_id)
+ self.get_user(user_id)
+ query = session.query(UserGroupMembership)
+ query = query.filter_by(user_id=user_id)
+ query = query.filter_by(group_id=group_id)
+ rv = query.first()
+ if rv:
+ return
+
+ with session.begin():
+ session.add(UserGroupMembership(user_id=user_id,
+ group_id=group_id))
+
+ def check_user_in_group(self, user_id, group_id):
+ session = sql.get_session()
+ self.get_group(group_id)
+ self.get_user(user_id)
+ query = session.query(UserGroupMembership)
+ query = query.filter_by(user_id=user_id)
+ query = query.filter_by(group_id=group_id)
+ if not query.first():
+ raise exception.NotFound(_("User '%(user_id)s' not found in"
+ " group '%(group_id)s'") %
+ {'user_id': user_id,
+ 'group_id': group_id})
+
+ def remove_user_from_group(self, user_id, group_id):
+ session = sql.get_session()
+ # We don't check if user or group are still valid and let the remove
+ # be tried anyway - in case this is some kind of clean-up operation
+ query = session.query(UserGroupMembership)
+ query = query.filter_by(user_id=user_id)
+ query = query.filter_by(group_id=group_id)
+ membership_ref = query.first()
+ if membership_ref is None:
+ # Check if the group and user exist to return descriptive
+ # exceptions.
+ self.get_group(group_id)
+ self.get_user(user_id)
+ raise exception.NotFound(_("User '%(user_id)s' not found in"
+ " group '%(group_id)s'") %
+ {'user_id': user_id,
+ 'group_id': group_id})
+ with session.begin():
+ session.delete(membership_ref)
+
+ def list_groups_for_user(self, user_id, hints):
+ # TODO(henry-nash) We could implement full filtering here by enhancing
+ # the join below. However, since it is likely to be a fairly rare
+ # occurrence to filter on more than the user_id already being used
+ # here, this is left as future enhancement and until then we leave
+ # it for the controller to do for us.
+ session = sql.get_session()
+ self.get_user(user_id)
+ query = session.query(Group).join(UserGroupMembership)
+ query = query.filter(UserGroupMembership.user_id == user_id)
+ return [g.to_dict() for g in query]
+
+ def list_users_in_group(self, group_id, hints):
+ # TODO(henry-nash) We could implement full filtering here by enhancing
+ # the join below. However, since it is likely to be a fairly rare
+ # occurrence to filter on more than the group_id already being used
+ # here, this is left as future enhancement and until then we leave
+ # it for the controller to do for us.
+ session = sql.get_session()
+ self.get_group(group_id)
+ query = session.query(User).join(UserGroupMembership)
+ query = query.filter(UserGroupMembership.group_id == group_id)
+
+ return [identity.filter_user(u.to_dict()) for u in query]
+
+ def delete_user(self, user_id):
+ session = sql.get_session()
+
+ with session.begin():
+ ref = self._get_user(session, user_id)
+
+ q = session.query(UserGroupMembership)
+ q = q.filter_by(user_id=user_id)
+ q.delete(False)
+
+ session.delete(ref)
+
+ # group crud
+
+ @sql.handle_conflicts(conflict_type='group')
+ def create_group(self, group_id, group):
+ session = sql.get_session()
+ with session.begin():
+ ref = Group.from_dict(group)
+ session.add(ref)
+ return ref.to_dict()
+
+ @sql.truncated
+ def list_groups(self, hints):
+ session = sql.get_session()
+ query = session.query(Group)
+ refs = sql.filter_limit_query(Group, query, hints)
+ return [ref.to_dict() for ref in refs]
+
+ def _get_group(self, session, group_id):
+ ref = session.query(Group).get(group_id)
+ if not ref:
+ raise exception.GroupNotFound(group_id=group_id)
+ return ref
+
+ def get_group(self, group_id):
+ session = sql.get_session()
+ return self._get_group(session, group_id).to_dict()
+
+ def get_group_by_name(self, group_name, domain_id):
+ session = sql.get_session()
+ query = session.query(Group)
+ query = query.filter_by(name=group_name)
+ query = query.filter_by(domain_id=domain_id)
+ try:
+ group_ref = query.one()
+ except sql.NotFound:
+ raise exception.GroupNotFound(group_id=group_name)
+ return group_ref.to_dict()
+
+ @sql.handle_conflicts(conflict_type='group')
+ def update_group(self, group_id, group):
+ session = sql.get_session()
+
+ with session.begin():
+ ref = self._get_group(session, group_id)
+ old_dict = ref.to_dict()
+ for k in group:
+ old_dict[k] = group[k]
+ new_group = Group.from_dict(old_dict)
+ for attr in Group.attributes:
+ if attr != 'id':
+ setattr(ref, attr, getattr(new_group, attr))
+ ref.extra = new_group.extra
+ return ref.to_dict()
+
+ def delete_group(self, group_id):
+ session = sql.get_session()
+
+ with session.begin():
+ ref = self._get_group(session, group_id)
+
+ q = session.query(UserGroupMembership)
+ q = q.filter_by(group_id=group_id)
+ q.delete(False)
+
+ session.delete(ref)